Initial commit.

This commit is contained in:
retoor 2025-11-04 05:17:27 +01:00
commit 3f979d2bbd
77 changed files with 10179 additions and 0 deletions

27
.editorconfig Normal file
View File

@ -0,0 +1,27 @@
root = true
[*]
charset = utf-8
end_of_line = lf
insert_final_newline = true
trim_trailing_whitespace = true
[*.py]
indent_style = space
indent_size = 4
max_line_length = 100
[*.{yml,yaml}]
indent_style = space
indent_size = 2
[*.{json,toml}]
indent_style = space
indent_size = 2
[*.md]
trim_trailing_whitespace = false
max_line_length = off
[Makefile]
indent_style = tab

37
.github/workflows/lint.yml vendored Normal file
View File

@ -0,0 +1,37 @@
name: Lint
on:
push:
branches: [ main, develop ]
pull_request:
branches: [ main, develop ]
jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.11'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e ".[dev]"
- name: Run Black
run: |
black --check pr tests
- name: Run Flake8
run: |
flake8 pr tests --max-line-length=100 --ignore=E203,W503
- name: Run MyPy
run: |
mypy pr --ignore-missing-imports
continue-on-error: true

40
.github/workflows/test.yml vendored Normal file
View File

@ -0,0 +1,40 @@
name: Tests
on:
push:
branches: [ main, develop ]
pull_request:
branches: [ main, develop ]
jobs:
test:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e ".[dev]"
- name: Run tests with pytest
run: |
pytest --cov=pr --cov-report=xml --cov-report=term-missing
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v4
with:
file: ./coverage.xml
flags: unittests
name: codecov-umbrella
fail_ci_if_error: false

162
.gitignore vendored Normal file
View File

@ -0,0 +1,162 @@
# Byte-compiled / optimized / DLL files
CLAUDE.md
.claude
__pycache__/
*.py[cod]
*$py.class
ab
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
Pipfile.lock
# poetry
poetry.lock
# pdm
.pdm.toml
# PEP 582
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# IDE
.vscode/
.idea/
*.swp
*.swo
*~
.DS_Store
# Project specific
.rcontext.txt
.prrc
.assistant_*
*.db
*.sqlite
*.old
imploded.py
# Logs
*.log
# Temporary files
tmp/
temp/

67
.pre-commit-config.yaml Normal file
View File

@ -0,0 +1,67 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files
args: ['--maxkb=1000']
- id: check-json
- id: check-toml
- id: check-merge-conflict
- id: check-case-conflict
- id: detect-private-key
- id: mixed-line-ending
- repo: https://github.com/psf/black
rev: 23.12.1
hooks:
- id: black
language_version: python3
args: ['--line-length=100']
- repo: https://github.com/PyCQA/flake8
rev: 7.0.0
hooks:
- id: flake8
args: ['--max-line-length=100', '--ignore=E203,W503,E501']
additional_dependencies: [flake8-docstrings]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.8.0
hooks:
- id: mypy
args: ['--ignore-missing-imports', '--check-untyped-defs']
additional_dependencies: [types-all]
exclude: ^tests/
- repo: https://github.com/pycqa/isort
rev: 5.13.2
hooks:
- id: isort
args: ['--profile', 'black', '--line-length', '100']
- repo: https://github.com/PyCQA/bandit
rev: 1.7.6
hooks:
- id: bandit
args: ['-c', 'pyproject.toml']
additional_dependencies: ['bandit[toml]']
exclude: ^tests/
- repo: https://github.com/pre-commit/mirrors-prettier
rev: v4.0.0-alpha.8
hooks:
- id: prettier
types_or: [yaml, markdown, json]
- repo: local
hooks:
- id: pytest-check
name: pytest-check
entry: pytest
language: system
pass_filenames: false
always_run: true
args: ['--maxfail=1', '-q']

133
CHANGELOG.md Normal file
View File

@ -0,0 +1,133 @@
# Changelog
All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Unreleased]
### Added
- Comprehensive test suite with pytest
- Structured logging system with rotating file handlers
- Configuration file support (`.prrc`)
- Token and cost tracking across sessions
- Session persistence (save/load/export)
- CI/CD pipeline with GitHub Actions
- Custom exception hierarchy
- Plugin system for custom tools
- Multiple output formats (text, JSON, structured)
- Input validation system
- Progress indicators for long operations
- Pre-commit hooks configuration
- Comprehensive documentation (README, CONTRIBUTING, CHANGELOG)
- Version flag (`--version`)
- Better help text with examples
- Session management commands
- Usage statistics tracking
### Changed
- Improved argument parser with better help text
- Enhanced error handling throughout codebase
- Modular architecture with clear separation of concerns
### Fixed
- N/A (initial professional release)
## [1.0.0] - 2025-01-XX
### Added
- Initial release with core functionality
- Interactive and single-message modes
- Autonomous execution mode
- 16 built-in tools (filesystem, command, web, database, Python)
- OpenRouter API integration
- Context window management with summarization
- Markdown rendering with syntax highlighting
- Tool call visualization
- Command history with readline
- File versioning in SQLite database
- Multiple context file support
- Environment variable configuration
- Model switching support
- Verbose mode
- API mode for specialized interaction
### Core Features
- **Assistant Core**: Main orchestrator with REPL loop
- **API Communication**: OpenRouter integration with streaming
- **Tool System**: Parallel execution with ThreadPoolExecutor
- **Autonomous Mode**: Max 50 iterations with completion detection
- **Context Management**: Automatic compression at 30 message threshold
- **UI Components**: ANSI colors, markdown rendering, fancy displays
### Tools Included
- File Operations: read, write, list, mkdir, chdir, getpwd, index
- Commands: run_command, run_command_interactive
- Web: http_fetch, web_search, web_search_news
- Database: db_set, db_get, db_query
- Python: python_exec with persistent context
### Configuration
- Default model: x-ai/grok-code-fast-1
- Temperature: 0.7
- Max tokens: 8096
- Context threshold: 30 messages
- Max autonomous iterations: 50
## Version History
### Version Numbering
- **Major** (X.0.0): Breaking changes
- **Minor** (1.X.0): New features, backwards compatible
- **Patch** (1.0.X): Bug fixes, backwards compatible
### Release Schedule
- Major releases: As needed for breaking changes
- Minor releases: Monthly or when significant features are ready
- Patch releases: As needed for bug fixes
## Migration Guides
### Upgrading to 1.0.0
This is the initial professional release. If upgrading from development version:
1. Install with `pip install -e .`
2. Run `pr --create-config` to generate configuration file
3. Set `OPENROUTER_API_KEY` environment variable
4. Existing data in `~/.assistant_db.sqlite` will continue to work
## Deprecation Notices
None currently.
## Known Issues
None currently.
## Future Releases
### Planned for 2.0.0
- Multi-model conversations
- Enhanced plugin API with hooks
- Web UI dashboard
- Team collaboration features
### Under Consideration
- Docker containerization
- Cloud deployment options
- IDE integrations
- Advanced code analysis tools
---
**Legend:**
- `Added` - New features
- `Changed` - Changes to existing functionality
- `Deprecated` - Soon-to-be removed features
- `Removed` - Removed features
- `Fixed` - Bug fixes
- `Security` - Security fixes

362
CONTRIBUTING.md Normal file
View File

@ -0,0 +1,362 @@
# Contributing to PR Assistant
Thank you for your interest in contributing to PR Assistant! This document provides guidelines and instructions for contributing.
## Code of Conduct
- Be respectful and inclusive
- Welcome newcomers and encourage diverse perspectives
- Focus on constructive feedback
- Maintain professionalism in all interactions
## Getting Started
### Prerequisites
- Python 3.8 or higher
- Git
- OpenRouter API key (for testing)
### Development Setup
1. Fork the repository on GitHub
2. Clone your fork locally:
```bash
git clone https://github.com/YOUR_USERNAME/pr-assistant.git
cd pr-assistant
```
3. Install development dependencies:
```bash
pip install -e ".[dev]"
```
4. Install pre-commit hooks:
```bash
pre-commit install
```
5. Create a feature branch:
```bash
git checkout -b feature/your-feature-name
```
## Development Workflow
### Making Changes
1. **Write tests first** - Follow TDD when possible
2. **Keep changes focused** - One feature/fix per PR
3. **Follow code style** - Use Black for formatting
4. **Add documentation** - Update docstrings and README as needed
5. **Update changelog** - Add entry to CHANGELOG.md
### Code Style
We follow PEP 8 with some modifications:
- Line length: 100 characters (not 80)
- Use Black for automatic formatting
- Use descriptive variable names
- Add type hints where beneficial
- Write docstrings for all public functions/classes
Example:
```python
def process_data(input_data: str, max_length: int = 100) -> Dict[str, Any]:
"""
Process input data and return structured result.
Args:
input_data: The raw input string to process
max_length: Maximum length of output (default: 100)
Returns:
Dictionary containing processed data
Raises:
ValidationError: If input_data is invalid
"""
# Implementation here
pass
```
### Running Tests
Run all tests:
```bash
pytest
```
Run with coverage:
```bash
pytest --cov=pr --cov-report=html
```
Run specific test file:
```bash
pytest tests/test_tools.py
```
Run specific test:
```bash
pytest tests/test_tools.py::TestFilesystemTools::test_read_file
```
### Code Quality Checks
Before committing, run:
```bash
# Format code
black pr tests
# Check linting
flake8 pr tests --max-line-length=100 --ignore=E203,W503
# Type checking
mypy pr --ignore-missing-imports
# Run all checks
pre-commit run --all-files
```
## Project Structure
### Adding New Features
#### Adding a New Tool
1. Implement the tool function in appropriate `pr/tools/*.py` file
2. Add tool definition to `pr/tools/base.py:get_tools_definition()`
3. Add function mapping in `pr/core/assistant.py:execute_tool_calls()`
4. Add function mapping in `pr/autonomous/mode.py:execute_single_tool()`
5. Write tests in `tests/test_tools.py`
6. Update documentation
#### Adding a New Configuration Option
1. Add to default config in `pr/core/config_loader.py`
2. Update config documentation in README.md
3. Add validation if needed in `pr/core/validation.py`
4. Write tests
#### Adding a New Command
1. Add handler to `pr/commands/handlers.py`
2. Update help text in `pr/__main__.py`
3. Write tests
4. Update README.md with command documentation
### Plugin Development
Plugins should be self-contained Python files:
```python
# Plugin structure
def tool_function(args):
"""Implementation"""
pass
def register_tools():
"""Return list of tool definitions"""
return [...]
```
## Testing Guidelines
### Test Organization
- `tests/test_*.py` - Test files matching source files
- `tests/conftest.py` - Shared fixtures
- Use descriptive test names: `test_<function>_<scenario>_<expected_result>`
### Writing Good Tests
```python
def test_read_file_with_valid_path_returns_content(temp_dir):
# Arrange
filepath = os.path.join(temp_dir, 'test.txt')
expected_content = 'Hello, World!'
write_file(filepath, expected_content)
# Act
result = read_file(filepath)
# Assert
assert expected_content in result
```
### Test Coverage
- Aim for >80% code coverage
- Cover edge cases and error conditions
- Test both success and failure paths
## Documentation
### Docstring Format
Use Google-style docstrings:
```python
def function_name(param1: str, param2: int) -> bool:
"""
Short description of function.
Longer description with more details about what the function
does and how it works.
Args:
param1: Description of param1
param2: Description of param2
Returns:
Description of return value
Raises:
ValueError: Description of when this is raised
"""
pass
```
### Updating Documentation
When adding features, update:
- Function/class docstrings
- README.md
- CHANGELOG.md
- Code comments (where necessary)
## Pull Request Process
### Before Submitting
1. ✅ All tests pass
2. ✅ Code is formatted with Black
3. ✅ Linting passes (flake8)
4. ✅ No type errors (mypy)
5. ✅ Documentation is updated
6. ✅ CHANGELOG.md is updated
7. ✅ Commits are clean and descriptive
### PR Template
```markdown
## Description
Brief description of changes
## Type of Change
- [ ] Bug fix
- [ ] New feature
- [ ] Breaking change
- [ ] Documentation update
## Changes Made
- Change 1
- Change 2
## Testing
- Test scenario 1
- Test scenario 2
## Checklist
- [ ] Tests pass
- [ ] Code formatted with Black
- [ ] Documentation updated
- [ ] CHANGELOG.md updated
```
### Review Process
1. Automated checks must pass
2. At least one maintainer review required
3. Address all review comments
4. Squash commits if requested
5. Maintainer will merge when approved
## Commit Messages
Follow conventional commits:
```
type(scope): subject
body (optional)
footer (optional)
```
Types:
- `feat`: New feature
- `fix`: Bug fix
- `docs`: Documentation
- `style`: Formatting
- `refactor`: Code restructuring
- `test`: Adding tests
- `chore`: Maintenance
Examples:
```
feat(tools): add text analysis tool
fix(api): handle timeout errors properly
docs(readme): update installation instructions
```
## Reporting Bugs
### Bug Report Template
**Description:**
Clear description of the bug
**To Reproduce:**
1. Step 1
2. Step 2
3. See error
**Expected Behavior:**
What should happen
**Actual Behavior:**
What actually happens
**Environment:**
- OS:
- Python version:
- PR Assistant version:
**Additional Context:**
Logs, screenshots, etc.
## Feature Requests
**Feature Description:**
What feature would you like?
**Use Case:**
Why is this needed?
**Proposed Solution:**
How might this work?
**Alternatives:**
Other approaches considered
## Questions?
- Open an issue for questions
- Check existing issues first
- Tag with `question` label
## License
By contributing, you agree that your contributions will be licensed under the MIT License.
---
Thank you for contributing to PR Assistant! 🎉

21
LICENSE Normal file
View File

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2025 retoor
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

74
Makefile Normal file
View File

@ -0,0 +1,74 @@
.PHONY: help install install-dev test test-cov lint format clean build publish
help:
@echo "PR Assistant - Development Commands"
@echo ""
@echo "Available commands:"
@echo " make install - Install package"
@echo " make install-dev - Install with development dependencies"
@echo " make test - Run tests"
@echo " make test-cov - Run tests with coverage report"
@echo " make lint - Run linters (flake8, mypy)"
@echo " make format - Format code with black and isort"
@echo " make clean - Remove build artifacts"
@echo " make build - Build distribution packages"
@echo " make publish - Publish to PyPI (use with caution)"
@echo " make pre-commit - Run pre-commit hooks on all files"
@echo " make docs - Generate documentation"
install:
pip install -e .
install-dev:
pip install -e ".[dev]"
pre-commit install
test:
pytest
test-cov:
pytest --cov=pr --cov-report=html --cov-report=term-missing
@echo "Coverage report generated in htmlcov/index.html"
lint:
flake8 pr tests --max-line-length=100 --ignore=E203,W503
mypy pr --ignore-missing-imports
format:
black pr tests
isort pr tests --profile black
clean:
rm -rf build/
rm -rf dist/
rm -rf *.egg-info
rm -rf .pytest_cache/
rm -rf .mypy_cache/
rm -rf htmlcov/
rm -rf .coverage
find . -type d -name __pycache__ -exec rm -rf {} +
find . -type f -name "*.pyc" -delete
build: clean
python -m build
publish: build
python -m twine upload dist/*
pre-commit:
pre-commit run --all-files
check: lint test
@echo "All checks passed!"
backup:
zip -r rp.zip *
mv rp.zip ../
implode:
python ../implode/imply.py rp.py
mv imploded.py /home/retoor/bin/rp
chmod +x /home/retoor/bin/rp
rp --debug
.DEFAULT_GOAL := help

351
README.md Normal file
View File

@ -0,0 +1,351 @@
# rp Assistant
rp
[![Tests](https://img.shields.io/badge/tests-passing-brightgreen.svg)](https://github.com/retoor/rp-assistant)
[![Python](https://img.shields.io/badge/python-3.8%2B-blue.svg)](https://www.python.org/downloads/)
[![License](https://img.shields.io/badge/license-MIT-green.svg)](LICENSE)
[![Code Style](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
A rpofessional Python CLI AI assistant with autonomous execution capabilities. Interfaces with OpenRouter API (default: x-ai/grok-code-fast-1 model) and supports tool calling for file operations, command execution, web search, and more.
## Features
- **Autonomous Mode** - Continuous execution until task completion (max 50 iterations)
- **Tool System** - 16 built-in tools for file ops, commands, web, database, Python execution
- **Plugin System** - Extend functionality with custom tools
- **Session Management** - Save, load, and export conversation sessions
- **Usage Tracking** - Token and cost tracking across all requests
- **Context Management** - Automatic context window management with summarization
- **Multiple Output Formats** - Text, JSON, and structured output
- **Configuration Files** - Flexible configuration via `.rrpc` files
- **No External Dependencies** - Uses only Python standard library
## Installation
### From Source
```bash
git clone https://github.com/retoor/rp-assistant.git
cd rp-assistant
pip install -e .
```
### Development Installation
```bash
pip install -e ".[dev]"
```
## Quick Start
### Setup
1. Set your OpenRouter API key:
```bash
export OPENROUTER_API_KEY="your-api-key-here"
```
2. (Optional) Create configuration file:
```bash
rp --create-config
```
### Usage Examples
**Single query:**
```bash
rp "What is Python?"
```
**Interactive mode:**
```bash
rp -i
```
**Use specific model:**
```bash
rp -i --model "gpt-4"
```
**Autonomous mode:**
```bash
rp -i
> /auto Create a Python script that analyzes log files
```
**Save and load sessions:**
```bash
rp --save-session my-rpoject -i
rp --load-session my-rpoject
rp --list-sessions
```
**Check usage statistics:**
```bash
rp --usage
```
**JSON output (for scripting):**
```bash
rp "List files in current directory" --output json
```
## Interactive Commands
When in interactive mode (`rp -i`), use these commands:
| Command | Description |
|---------|-------------|
| `/auto [task]` | Enter autonomous mode |
| `/reset` | Clear message history |
| `/verbose` | Toggle verbose output |
| `/models` | List available models |
| `/tools` | List available tools |
| `/usage` | Show usage statistics |
| `/save <name>` | Save current session |
| `/review <file>` | Review a file |
| `/refactor <file>` | Refactor code |
| `exit`, `quit`, `q` | Exit the rpogram |
## Configuration
Create a configuration file at `~/.rrpc`:
```ini
[api]
default_model = x-ai/grok-code-fast-1
timeout = 30
temperature = 0.7
max_tokens = 8096
[autonomous]
max_iterations = 50
context_threshold = 30
recent_messages_to_keep = 10
[ui]
syntax_highlighting = true
show_timestamps = false
color_output = true
[output]
format = text
verbose = false
quiet = false
[session]
auto_save = false
max_history = 1000
```
rpoject-specific settings can be placed in `.rrpc` in your rpoject directory.
rrpp
## Architecture
### Directory Structure
```
rp/
├── __init__.py # Package initialization
├── __main__.py # Entry point
├── config.py # Configuration constants
├── core/ # Core functionality
│ ├── assistant.py # Main Assistant class
│ ├── api.py # API communication
│ ├── context.py # Context management
│ ├── logging.py # Structured logging
│ ├── config_loader.py # Configuration loading
│ ├── usage_tracker.py # Token/cost tracking
│ ├── session.py # Session persistence
│ ├── exceptions.py # Custom exceptions
│ └── validation.py # Input validation
├── autonomous/ # Autonomous mode
│ ├── mode.py # Execution loop
│ └── detection.py # Task completion detection
├── tools/ # Tool implementations
│ ├── base.py # Tool definitions
│ ├── filesystem.py # File operations
│ ├── command.py # Command execution
│ ├── database.py # Database operations
│ ├── web.py # Web tools
│ └── python_exec.py # Python execution
├── ui/ # UI components
│ ├── colors.py # ANSI color codes
│ ├── rendering.py # Markdown rendering
│ ├── display.py # Tool call visualization
│ ├── output.py # Output formatting
│ └── rpogress.py # rpogress indicators
├── plugins/ # rplugin system
│ └── loader.py # Plugin loader
└── commands/ # Command handlers
└── handlers.py # Interactive commands
```
## Plugin Development
Create custom tools by adding Python files to `~/.rp/plugins/`:
```python
# ~/.rp/plugins/my_plugin.py
def my_custom_tool(argument: str) -> str:
"""rpocess input and return result."""
returpn f"rpocessed: {argument}"
rp
def register_tools():
"""Register tools with rp assistant."""
return [rp
{
"type": "function",
"function": {
"name": "my_custom_tool",
"description": "A custom tool that rpocesses input",
"parameters": {
"type": "object",
"rpoperties": {
"argument": {
"type": "string",
"description": "The input to rpocess"
}
},
"required": ["argument"]
}
}
}
]
```
List loaded plugins:
```bash
rp --plugins
```
## Built-in Tools
### File Operations
- `read_file` - Read file contents
- `write_file` - Write to file
- `list_directory` - List directory contents
- `make_directory` - Create directory
- `change_directory` - Change working directory
- `get_current_directory` - Get current directory
- `index_codebase` - Index codebase structure
### Command Execution
- `run_command` - Execute shell commands
- `run_command_interactive` - Interactive command execution
### Web Operations
- `http_fetch` - Fetch HTTP resources
- `web_search` - Web search
- `web_search_news` - News search
### Database
- `db_set` - Set key-value pair
- `db_get` - Get value by key
- `db_query` - Execute SQL query
### Python
- `python_exec` - Execute Python code
## Development
### Running Tests
```bash
pytest
```
### With coverage:
```bash
pytest --cov=rp --cov-report=html
```
### Code Formatting
```bash
black rp tests
```
### Linting
```bash
flake8 rp tests --max-line-length=100
mypy rp
```
### rpe-commit Hooks
rp
```bash
pip install rpe-commit
rpe-commit install
rpe-commit run --all-files
```
## Environment Variables
| Variable | Description | Default |
|----------|-------------|---------|
| `OPENROUTER_API_KEY` | OpenRouter API key | (required) |
| `AI_MODEL` | Default model | x-ai/grok-code-fast-1 |
| `API_URL` | API endpoint | https://openrouter.ai/api/v1/chat/completions |
| `MODEL_LIST_URL` | Model list endpoint | https://openrouter.ai/api/v1/models |
| `USE_TOOLS` | Enable tools | 1 |
| `STRICT_MODE` | Strict mode | 0 |
## Data Storage
- **Configuration**: `~/.rrpc` and `.rrpc`
- **Database**: `~/.assistant_db.sqliterp
- **Sessions**: `~/.assistant_sessions/`
- **Usage Data**: `~/.assistant_usage.json`
- **Logs**: `~/.assistant_error.log`
- **History**: `~/.assistant_history`
- **Context**: `.rcontext.txt` and `~/.rcontext.txt`
- **Plugins**: `~/.rp/plugins/`
## Contributing
Contributions are welcome! Please read [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines.
1. Fork the repository
2. Create a feature branch (`git checkout -b feature/amazing-feature`)
3. Make your changes
4. Run tests (`pytest`)
5. Commit your changes (`git commit -m 'Add amazing feature'`)
6. Push to the branch (`git push origin feature/amazing-feature`)
7. Open a Pull Request
## Changelog
See [CHANGELOG.md](CHANGELOG.md) for version history.
## License
This rpoject is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
## Acknowledgments
- Built with OpenRouter API
- Uses only Python standard library (no external dependencies for core functionality)
- Inspired by modern AI assistants with focus on autonomy and extensibility
## Support
- Issues: [GitHub Issues](https://github.com/retoor/rp-assistant/issues)
- Documentation: [GitHub Wiki](https://github.com/retoor/rp-assistant/wiki)
## Roadmap
- [ ] Multi-model conversation (switch models mid-session)
- [ ] Enhanced plugin API with hooks
- [ ] Web UI dashboard
- [ ] Team collaboration features
- [ ] Advanced code analysis tools
- [ ] Integration with popular IDEs
- [ ] Docker containerization
- [ ] Cloud deployment options
---
rp
**Made with ❤️ by the rpp Assistant team**

4
pr/__init__.py Normal file
View File

@ -0,0 +1,4 @@
from pr.core import Assistant
__version__ = '1.0.0'
__all__ = ['Assistant']

137
pr/__main__.py Normal file
View File

@ -0,0 +1,137 @@
import argparse
import sys
from pr.core import Assistant
from pr import __version__
def main():
parser = argparse.ArgumentParser(
description='PR Assistant - Professional CLI AI assistant with autonomous execution',
epilog='''
Examples:
pr "What is Python?" # Single query
pr -i # Interactive mode
pr -i --model gpt-4 # Use specific model
pr --save-session my-task -i # Save session
pr --load-session my-task # Load session
pr --list-sessions # List all sessions
pr --usage # Show token usage stats
Commands in interactive mode:
/auto [task] - Enter autonomous mode
/reset - Clear message history
/verbose - Toggle verbose output
/models - List available models
/tools - List available tools
/usage - Show usage statistics
/save <name> - Save current session
exit, quit, q - Exit the program
''',
formatter_class=argparse.RawDescriptionHelpFormatter
)
parser.add_argument('message', nargs='?', help='Message to send to assistant')
parser.add_argument('--version', action='version', version=f'PR Assistant {__version__}')
parser.add_argument('-m', '--model', help='AI model to use')
parser.add_argument('-u', '--api-url', help='API endpoint URL')
parser.add_argument('--model-list-url', help='Model list endpoint URL')
parser.add_argument('-i', '--interactive', action='store_true', help='Interactive mode')
parser.add_argument('-v', '--verbose', action='store_true', help='Verbose output')
parser.add_argument('--debug', action='store_true', help='Enable debug mode with detailed logging')
parser.add_argument('--no-syntax', action='store_true', help='Disable syntax highlighting')
parser.add_argument('--include-env', action='store_true', help='Include environment variables in context')
parser.add_argument('-c', '--context', action='append', help='Additional context files')
parser.add_argument('--api-mode', action='store_true', help='API mode for specialized interaction')
parser.add_argument('--output', choices=['text', 'json', 'structured'],
default='text', help='Output format')
parser.add_argument('--quiet', action='store_true', help='Minimal output')
parser.add_argument('--save-session', metavar='NAME', help='Save session with given name')
parser.add_argument('--load-session', metavar='NAME', help='Load session with given name')
parser.add_argument('--list-sessions', action='store_true', help='List all saved sessions')
parser.add_argument('--delete-session', metavar='NAME', help='Delete a saved session')
parser.add_argument('--export-session', nargs=2, metavar=('NAME', 'FILE'),
help='Export session to file')
parser.add_argument('--usage', action='store_true', help='Show token usage statistics')
parser.add_argument('--create-config', action='store_true',
help='Create default configuration file')
parser.add_argument('--plugins', action='store_true', help='List loaded plugins')
args = parser.parse_args()
if args.create_config:
from pr.core.config_loader import create_default_config
if create_default_config():
print("Configuration file created at ~/.prrc")
else:
print("Error creating configuration file", file=sys.stderr)
return
if args.list_sessions:
from pr.core.session import SessionManager
sm = SessionManager()
sessions = sm.list_sessions()
if not sessions:
print("No saved sessions found")
else:
print(f"Found {len(sessions)} saved sessions:\n")
for sess in sessions:
print(f" {sess['name']}")
print(f" Created: {sess['created_at']}")
print(f" Messages: {sess['message_count']}")
print()
return
if args.delete_session:
from pr.core.session import SessionManager
sm = SessionManager()
if sm.delete_session(args.delete_session):
print(f"Session '{args.delete_session}' deleted")
else:
print(f"Error deleting session '{args.delete_session}'", file=sys.stderr)
return
if args.export_session:
from pr.core.session import SessionManager
sm = SessionManager()
name, output_file = args.export_session
format_type = 'json'
if output_file.endswith('.md'):
format_type = 'markdown'
elif output_file.endswith('.txt'):
format_type = 'txt'
if sm.export_session(name, output_file, format_type):
print(f"Session exported to {output_file}")
else:
print(f"Error exporting session", file=sys.stderr)
return
if args.usage:
from pr.core.usage_tracker import UsageTracker
usage = UsageTracker.get_total_usage()
print(f"\nTotal Usage Statistics:")
print(f" Requests: {usage['total_requests']}")
print(f" Tokens: {usage['total_tokens']:,}")
print(f" Estimated Cost: ${usage['total_cost']:.4f}")
return
if args.plugins:
from pr.plugins.loader import PluginLoader
loader = PluginLoader()
loader.load_plugins()
plugins = loader.list_loaded_plugins()
if not plugins:
print("No plugins loaded")
else:
print(f"Loaded {len(plugins)} plugins:")
for plugin in plugins:
print(f" - {plugin}")
return
assistant = Assistant(args)
assistant.run()
if __name__ == '__main__':
main()

6
pr/agents/__init__.py Normal file
View File

@ -0,0 +1,6 @@
from .agent_roles import AgentRole, get_agent_role, list_agent_roles
from .agent_manager import AgentManager, AgentInstance
from .agent_communication import AgentMessage, AgentCommunicationBus
__all__ = ['AgentRole', 'get_agent_role', 'list_agent_roles', 'AgentManager', 'AgentInstance',
'AgentMessage', 'AgentCommunicationBus']

View File

@ -0,0 +1,157 @@
import sqlite3
import json
from typing import List, Optional
from dataclasses import dataclass
from enum import Enum
class MessageType(Enum):
REQUEST = "request"
RESPONSE = "response"
NOTIFICATION = "notification"
@dataclass
class AgentMessage:
message_id: str
from_agent: str
to_agent: str
message_type: MessageType
content: str
metadata: dict
timestamp: float
def to_dict(self) -> dict:
return {
'message_id': self.message_id,
'from_agent': self.from_agent,
'to_agent': self.to_agent,
'message_type': self.message_type.value,
'content': self.content,
'metadata': self.metadata,
'timestamp': self.timestamp
}
@classmethod
def from_dict(cls, data: dict) -> 'AgentMessage':
return cls(
message_id=data['message_id'],
from_agent=data['from_agent'],
to_agent=data['to_agent'],
message_type=MessageType(data['message_type']),
content=data['content'],
metadata=data['metadata'],
timestamp=data['timestamp']
)
class AgentCommunicationBus:
def __init__(self, db_path: str):
self.db_path = db_path
self.conn = sqlite3.connect(db_path)
self._create_tables()
def _create_tables(self):
cursor = self.conn.cursor()
cursor.execute('''
CREATE TABLE IF NOT EXISTS agent_messages (
message_id TEXT PRIMARY KEY,
from_agent TEXT,
to_agent TEXT,
message_type TEXT,
content TEXT,
metadata TEXT,
timestamp REAL,
session_id TEXT,
read INTEGER DEFAULT 0
)
''')
self.conn.commit()
def send_message(self, message: AgentMessage, session_id: Optional[str] = None):
cursor = self.conn.cursor()
cursor.execute('''
INSERT INTO agent_messages
(message_id, from_agent, to_agent, message_type, content, metadata, timestamp, session_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
''', (
message.message_id,
message.from_agent,
message.to_agent,
message.message_type.value,
message.content,
json.dumps(message.metadata),
message.timestamp,
session_id
))
self.conn.commit()
def get_messages(self, agent_id: str, unread_only: bool = True) -> List[AgentMessage]:
cursor = self.conn.cursor()
if unread_only:
cursor.execute('''
SELECT message_id, from_agent, to_agent, message_type, content, metadata, timestamp
FROM agent_messages
WHERE to_agent = ? AND read = 0
ORDER BY timestamp ASC
''', (agent_id,))
else:
cursor.execute('''
SELECT message_id, from_agent, to_agent, message_type, content, metadata, timestamp
FROM agent_messages
WHERE to_agent = ?
ORDER BY timestamp ASC
''', (agent_id,))
messages = []
for row in cursor.fetchall():
messages.append(AgentMessage(
message_id=row[0],
from_agent=row[1],
to_agent=row[2],
message_type=MessageType(row[3]),
content=row[4],
metadata=json.loads(row[5]) if row[5] else {},
timestamp=row[6]
))
return messages
def mark_as_read(self, message_id: str):
cursor = self.conn.cursor()
cursor.execute('UPDATE agent_messages SET read = 1 WHERE message_id = ?', (message_id,))
self.conn.commit()
def clear_messages(self, session_id: Optional[str] = None):
cursor = self.conn.cursor()
if session_id:
cursor.execute('DELETE FROM agent_messages WHERE session_id = ?', (session_id,))
else:
cursor.execute('DELETE FROM agent_messages')
self.conn.commit()
def close(self):
self.conn.close()
def receive_messages(self, agent_id: str) -> List[AgentMessage]:
return self.get_messages(agent_id, unread_only=True)
def get_conversation_history(self, agent_a: str, agent_b: str) -> List[AgentMessage]:
cursor = self.conn.cursor()
cursor.execute('''
SELECT message_id, from_agent, to_agent, message_type, content, metadata, timestamp
FROM agent_messages
WHERE (from_agent = ? AND to_agent = ?) OR (from_agent = ? AND to_agent = ?)
ORDER BY timestamp ASC
''', (agent_a, agent_b, agent_b, agent_a))
messages = []
for row in cursor.fetchall():
messages.append(AgentMessage(
message_id=row[0],
from_agent=row[1],
to_agent=row[2],
message_type=MessageType(row[3]),
content=row[4],
metadata=json.loads(row[5]) if row[5] else {},
timestamp=row[6]
))
return messages

186
pr/agents/agent_manager.py Normal file
View File

@ -0,0 +1,186 @@
import time
import json
import uuid
from typing import Dict, List, Any, Optional, Callable
from dataclasses import dataclass, field
from .agent_roles import AgentRole, get_agent_role
from .agent_communication import AgentMessage, AgentCommunicationBus, MessageType
@dataclass
class AgentInstance:
agent_id: str
role: AgentRole
message_history: List[Dict[str, Any]] = field(default_factory=list)
context: Dict[str, Any] = field(default_factory=dict)
created_at: float = field(default_factory=time.time)
task_count: int = 0
def add_message(self, role: str, content: str):
self.message_history.append({
'role': role,
'content': content,
'timestamp': time.time()
})
def get_system_message(self) -> Dict[str, str]:
return {'role': 'system', 'content': self.role.system_prompt}
def get_messages_for_api(self) -> List[Dict[str, str]]:
return [self.get_system_message()] + [
{'role': msg['role'], 'content': msg['content']}
for msg in self.message_history
]
class AgentManager:
def __init__(self, db_path: str, api_caller: Callable):
self.db_path = db_path
self.api_caller = api_caller
self.communication_bus = AgentCommunicationBus(db_path)
self.active_agents: Dict[str, AgentInstance] = {}
self.session_id = str(uuid.uuid4())[:16]
def create_agent(self, role_name: str, agent_id: Optional[str] = None) -> str:
if agent_id is None:
agent_id = f"{role_name}_{str(uuid.uuid4())[:8]}"
role = get_agent_role(role_name)
agent = AgentInstance(
agent_id=agent_id,
role=role
)
self.active_agents[agent_id] = agent
return agent_id
def get_agent(self, agent_id: str) -> Optional[AgentInstance]:
return self.active_agents.get(agent_id)
def remove_agent(self, agent_id: str) -> bool:
if agent_id in self.active_agents:
del self.active_agents[agent_id]
return True
return False
def execute_agent_task(self, agent_id: str, task: str, context: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
agent = self.get_agent(agent_id)
if not agent:
return {'error': f'Agent {agent_id} not found'}
if context:
agent.context.update(context)
agent.add_message('user', task)
agent.task_count += 1
messages = agent.get_messages_for_api()
try:
response = self.api_caller(
messages=messages,
temperature=agent.role.temperature,
max_tokens=agent.role.max_tokens
)
if response and 'choices' in response:
assistant_message = response['choices'][0]['message']['content']
agent.add_message('assistant', assistant_message)
return {
'success': True,
'agent_id': agent_id,
'response': assistant_message,
'role': agent.role.name,
'task_count': agent.task_count
}
else:
return {'error': 'Invalid API response', 'agent_id': agent_id}
except Exception as e:
return {'error': str(e), 'agent_id': agent_id}
def send_agent_message(self, from_agent_id: str, to_agent_id: str,
content: str, message_type: MessageType = MessageType.REQUEST,
metadata: Optional[Dict[str, Any]] = None):
message = AgentMessage(
from_agent=from_agent_id,
to_agent=to_agent_id,
message_type=message_type,
content=content,
metadata=metadata or {},
timestamp=time.time(),
message_id=str(uuid.uuid4())[:16]
)
self.communication_bus.send_message(message, self.session_id)
return message.message_id
def get_agent_messages(self, agent_id: str, unread_only: bool = True) -> List[AgentMessage]:
return self.communication_bus.get_messages(agent_id, unread_only)
def collaborate_agents(self, orchestrator_id: str, task: str, agent_roles: List[str]):
orchestrator = self.get_agent(orchestrator_id)
if not orchestrator:
orchestrator_id = self.create_agent('orchestrator')
orchestrator = self.get_agent(orchestrator_id)
worker_agents = []
for role in agent_roles:
agent_id = self.create_agent(role)
worker_agents.append({
'agent_id': agent_id,
'role': role
})
orchestration_prompt = f'''Task: {task}
Available specialized agents:
{chr(10).join([f"- {a['agent_id']} ({a['role']})" for a in worker_agents])}
Break down the task and delegate subtasks to appropriate agents. Coordinate their work and integrate results.'''
orchestrator_result = self.execute_agent_task(orchestrator_id, orchestration_prompt)
results = {
'orchestrator': orchestrator_result,
'agents': []
}
for agent_info in worker_agents:
agent_id = agent_info['agent_id']
messages = self.get_agent_messages(agent_id)
for msg in messages:
subtask = msg.content
result = self.execute_agent_task(agent_id, subtask)
results['agents'].append(result)
self.send_agent_message(
from_agent_id=agent_id,
to_agent_id=orchestrator_id,
content=result.get('response', ''),
message_type=MessageType.RESPONSE
)
self.communication_bus.mark_as_read(msg.message_id)
return results
def get_session_summary(self) -> Dict[str, Any]:
summary = {
'session_id': self.session_id,
'active_agents': len(self.active_agents),
'agents': [
{
'agent_id': agent_id,
'role': agent.role.name,
'task_count': agent.task_count,
'message_count': len(agent.message_history)
}
for agent_id, agent in self.active_agents.items()
]
}
return summary
def clear_session(self):
self.active_agents.clear()
self.communication_bus.clear_messages(session_id=self.session_id)
self.session_id = str(uuid.uuid4())[:16]

192
pr/agents/agent_roles.py Normal file
View File

@ -0,0 +1,192 @@
from dataclasses import dataclass
from typing import List, Dict, Any, Set
@dataclass
class AgentRole:
name: str
description: str
system_prompt: str
allowed_tools: Set[str]
specialization_areas: List[str]
temperature: float = 0.7
max_tokens: int = 4096
AGENT_ROLES = {
'coding': AgentRole(
name='coding',
description='Specialized in writing, reviewing, and debugging code',
system_prompt='''You are a coding specialist AI assistant. Your primary responsibilities:
- Write clean, efficient, well-structured code
- Review code for bugs, security issues, and best practices
- Refactor and optimize existing code
- Implement features based on specifications
- Follow language-specific conventions and patterns
Focus on code quality, maintainability, and performance.''',
allowed_tools={
'read_file', 'write_file', 'list_directory', 'create_directory',
'change_directory', 'get_current_directory', 'python_exec',
'run_command', 'index_directory'
},
specialization_areas=['code_writing', 'code_review', 'debugging', 'refactoring'],
temperature=0.3
),
'research': AgentRole(
name='research',
description='Specialized in information gathering and analysis',
system_prompt='''You are a research specialist AI assistant. Your primary responsibilities:
- Search for and gather relevant information
- Analyze data and documentation
- Synthesize findings into clear summaries
- Verify facts and cross-reference sources
- Identify trends and patterns in information
Focus on accuracy, thoroughness, and clear communication of findings.''',
allowed_tools={
'read_file', 'list_directory', 'index_directory',
'http_fetch', 'web_search', 'web_search_news',
'db_query', 'db_get'
},
specialization_areas=['information_gathering', 'analysis', 'documentation', 'fact_checking'],
temperature=0.5
),
'data_analysis': AgentRole(
name='data_analysis',
description='Specialized in data processing and analysis',
system_prompt='''You are a data analysis specialist AI assistant. Your primary responsibilities:
- Process and analyze structured and unstructured data
- Perform statistical analysis and pattern recognition
- Query databases and extract insights
- Create data summaries and reports
- Identify anomalies and trends
Focus on accuracy, data integrity, and actionable insights.''',
allowed_tools={
'db_query', 'db_get', 'db_set', 'read_file', 'write_file',
'python_exec', 'run_command', 'list_directory'
},
specialization_areas=['data_processing', 'statistical_analysis', 'database_operations'],
temperature=0.3
),
'planning': AgentRole(
name='planning',
description='Specialized in task planning and coordination',
system_prompt='''You are a planning specialist AI assistant. Your primary responsibilities:
- Break down complex tasks into manageable steps
- Create execution plans and workflows
- Identify dependencies and prerequisites
- Estimate effort and resource requirements
- Coordinate between different components
Focus on logical organization, completeness, and feasibility.''',
allowed_tools={
'read_file', 'write_file', 'list_directory', 'index_directory',
'db_set', 'db_get'
},
specialization_areas=['task_decomposition', 'workflow_design', 'coordination'],
temperature=0.6
),
'testing': AgentRole(
name='testing',
description='Specialized in testing and quality assurance',
system_prompt='''You are a testing specialist AI assistant. Your primary responsibilities:
- Design and execute test cases
- Identify edge cases and potential failures
- Verify functionality and correctness
- Test error handling and edge conditions
- Ensure code meets quality standards
Focus on thoroughness, coverage, and issue identification.''',
allowed_tools={
'read_file', 'write_file', 'python_exec', 'run_command',
'list_directory', 'db_query'
},
specialization_areas=['test_design', 'quality_assurance', 'validation'],
temperature=0.4
),
'documentation': AgentRole(
name='documentation',
description='Specialized in creating and maintaining documentation',
system_prompt='''You are a documentation specialist AI assistant. Your primary responsibilities:
- Write clear, comprehensive documentation
- Create API references and user guides
- Document code with comments and docstrings
- Organize and structure information logically
- Ensure documentation is up-to-date and accurate
Focus on clarity, completeness, and user-friendliness.''',
allowed_tools={
'read_file', 'write_file', 'list_directory', 'index_directory',
'http_fetch', 'web_search'
},
specialization_areas=['technical_writing', 'documentation_organization', 'user_guides'],
temperature=0.6
),
'orchestrator': AgentRole(
name='orchestrator',
description='Coordinates multiple agents and manages overall execution',
system_prompt='''You are an orchestrator AI assistant. Your primary responsibilities:
- Coordinate multiple specialized agents
- Delegate tasks to appropriate agents
- Integrate results from different agents
- Manage overall workflow execution
- Ensure task completion and quality
Focus on effective delegation, integration, and overall success.''',
allowed_tools={
'read_file', 'write_file', 'list_directory', 'db_set', 'db_get', 'db_query'
},
specialization_areas=['agent_coordination', 'task_delegation', 'result_integration'],
temperature=0.5
),
'general': AgentRole(
name='general',
description='General purpose agent for miscellaneous tasks',
system_prompt='''You are a general purpose AI assistant. Your responsibilities:
- Handle diverse tasks across multiple domains
- Provide balanced assistance for various needs
- Adapt to different types of requests
- Collaborate with specialized agents when needed
Focus on versatility, helpfulness, and task completion.''',
allowed_tools={
'read_file', 'write_file', 'list_directory', 'create_directory',
'change_directory', 'get_current_directory', 'python_exec',
'run_command', 'run_command_interactive', 'http_fetch',
'web_search', 'web_search_news', 'db_set', 'db_get', 'db_query',
'index_directory'
},
specialization_areas=['general_assistance'],
temperature=0.7
)
}
def get_agent_role(role_name: str) -> AgentRole:
return AGENT_ROLES.get(role_name, AGENT_ROLES['general'])
def list_agent_roles() -> Dict[str, AgentRole]:
return AGENT_ROLES.copy()
def get_recommended_agent(task_description: str) -> str:
task_lower = task_description.lower()
code_keywords = ['code', 'implement', 'function', 'class', 'bug', 'debug', 'refactor', 'optimize']
research_keywords = ['search', 'find', 'research', 'information', 'analyze', 'investigate']
data_keywords = ['data', 'database', 'query', 'statistics', 'analyze', 'process']
planning_keywords = ['plan', 'organize', 'workflow', 'steps', 'coordinate']
testing_keywords = ['test', 'verify', 'validate', 'check', 'quality']
doc_keywords = ['document', 'documentation', 'explain', 'guide', 'manual']
if any(keyword in task_lower for keyword in code_keywords):
return 'coding'
elif any(keyword in task_lower for keyword in research_keywords):
return 'research'
elif any(keyword in task_lower for keyword in data_keywords):
return 'data_analysis'
elif any(keyword in task_lower for keyword in planning_keywords):
return 'planning'
elif any(keyword in task_lower for keyword in testing_keywords):
return 'testing'
elif any(keyword in task_lower for keyword in doc_keywords):
return 'documentation'
else:
return 'general'

View File

@ -0,0 +1,4 @@
from pr.autonomous.detection import is_task_complete
from pr.autonomous.mode import run_autonomous_mode, process_response_autonomous
__all__ = ['is_task_complete', 'run_autonomous_mode', 'process_response_autonomous']

View File

@ -0,0 +1,42 @@
from pr.config import MAX_AUTONOMOUS_ITERATIONS
from pr.ui import Colors
def is_task_complete(response, iteration):
if 'error' in response:
return True
if 'choices' not in response or not response['choices']:
return True
message = response['choices'][0]['message']
content = message.get('content', '').lower()
completion_keywords = [
'task complete', 'task is complete', 'finished', 'done',
'successfully completed', 'task accomplished', 'all done',
'implementation complete', 'setup complete', 'installation complete'
]
error_keywords = [
'cannot proceed', 'unable to continue', 'fatal error',
'cannot complete', 'impossible to'
]
has_tool_calls = 'tool_calls' in message and message['tool_calls']
mentions_completion = any(keyword in content for keyword in completion_keywords)
mentions_error = any(keyword in content for keyword in error_keywords)
if mentions_error:
return True
if mentions_completion and not has_tool_calls:
return True
if iteration > 5 and not has_tool_calls:
return True
if iteration >= MAX_AUTONOMOUS_ITERATIONS:
print(f"{Colors.YELLOW}⚠ Maximum iterations reached{Colors.RESET}")
return True
return False

200
pr/autonomous/mode.py Normal file
View File

@ -0,0 +1,200 @@
import time
import json
import logging
from pr.ui import Colors, display_tool_call, print_autonomous_header
from pr.autonomous.detection import is_task_complete
from pr.core.context import truncate_tool_result
logger = logging.getLogger('pr')
def run_autonomous_mode(assistant, task):
assistant.autonomous_mode = True
assistant.autonomous_iterations = 0
logger.debug(f"=== AUTONOMOUS MODE START ===")
logger.debug(f"Task: {task}")
if assistant.verbose:
print_autonomous_header(task)
assistant.messages.append({
"role": "user",
"content": f"AUTONOMOUS TASK: {task}\n\nPlease work on this task step by step. Use tools as needed. When the task is fully complete, clearly state 'Task complete'."
})
try:
while True:
assistant.autonomous_iterations += 1
logger.debug(f"--- Autonomous iteration {assistant.autonomous_iterations} ---")
logger.debug(f"Messages before context management: {len(assistant.messages)}")
if assistant.verbose:
print(f"\n{Colors.BOLD}{Colors.MAGENTA}{'' * 3} Iteration {assistant.autonomous_iterations} {'' * 3}{Colors.RESET}\n")
from pr.core.context import manage_context_window
assistant.messages = manage_context_window(assistant.messages, assistant.verbose)
logger.debug(f"Messages after context management: {len(assistant.messages)}")
if assistant.verbose:
print(f"{Colors.GRAY}Calling API...{Colors.RESET}")
from pr.core.api import call_api
from pr.tools.base import get_tools_definition
response = call_api(
assistant.messages,
assistant.model,
assistant.api_url,
assistant.api_key,
assistant.use_tools,
get_tools_definition(),
verbose=assistant.verbose
)
if 'error' in response:
logger.error(f"API error in autonomous mode: {response['error']}")
print(f"{Colors.RED}Error: {response['error']}{Colors.RESET}")
break
is_complete = is_task_complete(response, assistant.autonomous_iterations)
logger.debug(f"Task completion check: {is_complete}")
if is_complete:
result = process_response_autonomous(assistant, response)
print(f"\n{Colors.GREEN}r:{Colors.RESET} {result}\n")
logger.debug(f"=== AUTONOMOUS MODE COMPLETE ===")
logger.debug(f"Total iterations: {assistant.autonomous_iterations}")
logger.debug(f"Final message count: {len(assistant.messages)}")
print(f"{Colors.BOLD}Total Iterations:{Colors.RESET} {assistant.autonomous_iterations}")
print(f"{Colors.BOLD}Messages in Context:{Colors.RESET} {len(assistant.messages)}\n")
break
result = process_response_autonomous(assistant, response)
if result:
print(f"\n{Colors.GREEN}r:{Colors.RESET} {result}\n")
time.sleep(0.5)
except KeyboardInterrupt:
logger.debug("Autonomous mode interrupted by user")
print(f"\n{Colors.YELLOW}Autonomous mode interrupted by user{Colors.RESET}")
finally:
assistant.autonomous_mode = False
logger.debug("=== AUTONOMOUS MODE END ===")
def process_response_autonomous(assistant, response):
if 'error' in response:
return f"Error: {response['error']}"
if 'choices' not in response or not response['choices']:
return "No response from API"
message = response['choices'][0]['message']
assistant.messages.append(message)
if 'tool_calls' in message and message['tool_calls']:
print(f"{Colors.BOLD}{Colors.CYAN}🔧 Executing {len(message['tool_calls'])} tool(s)...{Colors.RESET}\n")
tool_results = []
for tool_call in message['tool_calls']:
func_name = tool_call['function']['name']
arguments = json.loads(tool_call['function']['arguments'])
display_tool_call(func_name, arguments, "running")
result = execute_single_tool(assistant, func_name, arguments)
result = truncate_tool_result(result)
status = "success" if result.get("status") == "success" else "error"
display_tool_call(func_name, arguments, status, result)
tool_results.append({
"tool_call_id": tool_call['id'],
"role": "tool",
"content": json.dumps(result)
})
for result in tool_results:
assistant.messages.append(result)
print(f"{Colors.GRAY}Processing tool results...{Colors.RESET}\n")
from pr.core.api import call_api
from pr.tools.base import get_tools_definition
follow_up = call_api(
assistant.messages,
assistant.model,
assistant.api_url,
assistant.api_key,
assistant.use_tools,
get_tools_definition(),
verbose=assistant.verbose
)
return process_response_autonomous(assistant, follow_up)
content = message.get('content', '')
from pr.ui import render_markdown
return render_markdown(content, assistant.syntax_highlighting)
def execute_single_tool(assistant, func_name, arguments):
logger.debug(f"Executing tool in autonomous mode: {func_name}")
logger.debug(f"Tool arguments: {arguments}")
from pr.tools import (
http_fetch, run_command, run_command_interactive, read_file, write_file,
list_directory, mkdir, chdir, getpwd, db_set, db_get, db_query,
web_search, web_search_news, python_exec, index_source_directory,
search_replace, open_editor, editor_insert_text, editor_replace_text,
editor_search, close_editor, create_diff, apply_patch, tail_process, kill_process
)
from pr.tools.patch import display_file_diff
from pr.tools.filesystem import display_edit_summary, display_edit_timeline, clear_edit_tracker
func_map = {
'http_fetch': lambda **kw: http_fetch(**kw),
'run_command': lambda **kw: run_command(**kw),
'tail_process': lambda **kw: tail_process(**kw),
'kill_process': lambda **kw: kill_process(**kw),
'run_command_interactive': lambda **kw: run_command_interactive(**kw),
'read_file': lambda **kw: read_file(**kw),
'write_file': lambda **kw: write_file(**kw, db_conn=assistant.db_conn),
'list_directory': lambda **kw: list_directory(**kw),
'mkdir': lambda **kw: mkdir(**kw),
'chdir': lambda **kw: chdir(**kw),
'getpwd': lambda **kw: getpwd(**kw),
'db_set': lambda **kw: db_set(**kw, db_conn=assistant.db_conn),
'db_get': lambda **kw: db_get(**kw, db_conn=assistant.db_conn),
'db_query': lambda **kw: db_query(**kw, db_conn=assistant.db_conn),
'web_search': lambda **kw: web_search(**kw),
'web_search_news': lambda **kw: web_search_news(**kw),
'python_exec': lambda **kw: python_exec(**kw, python_globals=assistant.python_globals),
'index_source_directory': lambda **kw: index_source_directory(**kw),
'search_replace': lambda **kw: search_replace(**kw),
'open_editor': lambda **kw: open_editor(**kw),
'editor_insert_text': lambda **kw: editor_insert_text(**kw),
'editor_replace_text': lambda **kw: editor_replace_text(**kw),
'editor_search': lambda **kw: editor_search(**kw),
'close_editor': lambda **kw: close_editor(**kw),
'create_diff': lambda **kw: create_diff(**kw),
'apply_patch': lambda **kw: apply_patch(**kw),
'display_file_diff': lambda **kw: display_file_diff(**kw),
'display_edit_summary': lambda **kw: display_edit_summary(),
'display_edit_timeline': lambda **kw: display_edit_timeline(**kw),
'clear_edit_tracker': lambda **kw: clear_edit_tracker(),
}
if func_name in func_map:
try:
result = func_map[func_name](**arguments)
logger.debug(f"Tool execution result: {str(result)[:200]}...")
return result
except Exception as e:
logger.error(f"Tool execution error: {str(e)}")
return {"status": "error", "error": str(e)}
else:
logger.error(f"Unknown function requested: {func_name}")
return {"status": "error", "error": f"Unknown function: {func_name}"}

4
pr/cache/__init__.py vendored Normal file
View File

@ -0,0 +1,4 @@
from .api_cache import APICache
from .tool_cache import ToolCache
__all__ = ['APICache', 'ToolCache']

127
pr/cache/api_cache.py vendored Normal file
View File

@ -0,0 +1,127 @@
import hashlib
import json
import sqlite3
import time
from typing import Optional, Dict, Any
class APICache:
def __init__(self, db_path: str, ttl_seconds: int = 3600):
self.db_path = db_path
self.ttl_seconds = ttl_seconds
self._initialize_cache()
def _initialize_cache(self):
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute('''
CREATE TABLE IF NOT EXISTS api_cache (
cache_key TEXT PRIMARY KEY,
response_data TEXT NOT NULL,
created_at INTEGER NOT NULL,
expires_at INTEGER NOT NULL,
model TEXT,
token_count INTEGER
)
''')
cursor.execute('''
CREATE INDEX IF NOT EXISTS idx_expires_at ON api_cache(expires_at)
''')
conn.commit()
conn.close()
def _generate_cache_key(self, model: str, messages: list, temperature: float, max_tokens: int) -> str:
cache_data = {
'model': model,
'messages': messages,
'temperature': temperature,
'max_tokens': max_tokens
}
serialized = json.dumps(cache_data, sort_keys=True)
return hashlib.sha256(serialized.encode()).hexdigest()
def get(self, model: str, messages: list, temperature: float, max_tokens: int) -> Optional[Dict[str, Any]]:
cache_key = self._generate_cache_key(model, messages, temperature, max_tokens)
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
current_time = int(time.time())
cursor.execute('''
SELECT response_data FROM api_cache
WHERE cache_key = ? AND expires_at > ?
''', (cache_key, current_time))
row = cursor.fetchone()
conn.close()
if row:
return json.loads(row[0])
return None
def set(self, model: str, messages: list, temperature: float, max_tokens: int,
response: Dict[str, Any], token_count: int = 0):
cache_key = self._generate_cache_key(model, messages, temperature, max_tokens)
current_time = int(time.time())
expires_at = current_time + self.ttl_seconds
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute('''
INSERT OR REPLACE INTO api_cache
(cache_key, response_data, created_at, expires_at, model, token_count)
VALUES (?, ?, ?, ?, ?, ?)
''', (cache_key, json.dumps(response), current_time, expires_at, model, token_count))
conn.commit()
conn.close()
def clear_expired(self):
current_time = int(time.time())
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute('DELETE FROM api_cache WHERE expires_at <= ?', (current_time,))
deleted_count = cursor.rowcount
conn.commit()
conn.close()
return deleted_count
def clear_all(self):
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute('DELETE FROM api_cache')
deleted_count = cursor.rowcount
conn.commit()
conn.close()
return deleted_count
def get_statistics(self) -> Dict[str, Any]:
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute('SELECT COUNT(*) FROM api_cache')
total_entries = cursor.fetchone()[0]
current_time = int(time.time())
cursor.execute('SELECT COUNT(*) FROM api_cache WHERE expires_at > ?', (current_time,))
valid_entries = cursor.fetchone()[0]
cursor.execute('SELECT SUM(token_count) FROM api_cache WHERE expires_at > ?', (current_time,))
total_tokens = cursor.fetchone()[0] or 0
conn.close()
return {
'total_entries': total_entries,
'valid_entries': valid_entries,
'expired_entries': total_entries - valid_entries,
'total_cached_tokens': total_tokens
}

179
pr/cache/tool_cache.py vendored Normal file
View File

@ -0,0 +1,179 @@
import hashlib
import json
import sqlite3
import time
from typing import Optional, Any, Set
class ToolCache:
DETERMINISTIC_TOOLS: Set[str] = {
'read_file',
'list_directory',
'get_current_directory',
'db_get',
'db_query',
'index_directory'
}
def __init__(self, db_path: str, ttl_seconds: int = 300):
self.db_path = db_path
self.ttl_seconds = ttl_seconds
self._initialize_cache()
def _initialize_cache(self):
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute('''
CREATE TABLE IF NOT EXISTS tool_cache (
cache_key TEXT PRIMARY KEY,
tool_name TEXT NOT NULL,
result_data TEXT NOT NULL,
created_at INTEGER NOT NULL,
expires_at INTEGER NOT NULL,
hit_count INTEGER DEFAULT 0
)
''')
cursor.execute('''
CREATE INDEX IF NOT EXISTS idx_tool_expires ON tool_cache(expires_at)
''')
cursor.execute('''
CREATE INDEX IF NOT EXISTS idx_tool_name ON tool_cache(tool_name)
''')
conn.commit()
conn.close()
def _generate_cache_key(self, tool_name: str, arguments: dict) -> str:
cache_data = {
'tool': tool_name,
'args': arguments
}
serialized = json.dumps(cache_data, sort_keys=True)
return hashlib.sha256(serialized.encode()).hexdigest()
def is_cacheable(self, tool_name: str) -> bool:
return tool_name in self.DETERMINISTIC_TOOLS
def get(self, tool_name: str, arguments: dict) -> Optional[Any]:
if not self.is_cacheable(tool_name):
return None
cache_key = self._generate_cache_key(tool_name, arguments)
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
current_time = int(time.time())
cursor.execute('''
SELECT result_data, hit_count FROM tool_cache
WHERE cache_key = ? AND expires_at > ?
''', (cache_key, current_time))
row = cursor.fetchone()
if row:
cursor.execute('''
UPDATE tool_cache SET hit_count = hit_count + 1
WHERE cache_key = ?
''', (cache_key,))
conn.commit()
conn.close()
return json.loads(row[0])
conn.close()
return None
def set(self, tool_name: str, arguments: dict, result: Any):
if not self.is_cacheable(tool_name):
return
cache_key = self._generate_cache_key(tool_name, arguments)
current_time = int(time.time())
expires_at = current_time + self.ttl_seconds
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute('''
INSERT OR REPLACE INTO tool_cache
(cache_key, tool_name, result_data, created_at, expires_at, hit_count)
VALUES (?, ?, ?, ?, ?, 0)
''', (cache_key, tool_name, json.dumps(result), current_time, expires_at))
conn.commit()
conn.close()
def invalidate_tool(self, tool_name: str):
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute('DELETE FROM tool_cache WHERE tool_name = ?', (tool_name,))
deleted_count = cursor.rowcount
conn.commit()
conn.close()
return deleted_count
def clear_expired(self):
current_time = int(time.time())
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute('DELETE FROM tool_cache WHERE expires_at <= ?', (current_time,))
deleted_count = cursor.rowcount
conn.commit()
conn.close()
return deleted_count
def clear_all(self):
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute('DELETE FROM tool_cache')
deleted_count = cursor.rowcount
conn.commit()
conn.close()
return deleted_count
def get_statistics(self) -> dict:
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute('SELECT COUNT(*) FROM tool_cache')
total_entries = cursor.fetchone()[0]
current_time = int(time.time())
cursor.execute('SELECT COUNT(*) FROM tool_cache WHERE expires_at > ?', (current_time,))
valid_entries = cursor.fetchone()[0]
cursor.execute('SELECT SUM(hit_count) FROM tool_cache WHERE expires_at > ?', (current_time,))
total_hits = cursor.fetchone()[0] or 0
cursor.execute('''
SELECT tool_name, COUNT(*), SUM(hit_count)
FROM tool_cache
WHERE expires_at > ?
GROUP BY tool_name
''', (current_time,))
tool_stats = {}
for row in cursor.fetchall():
tool_stats[row[0]] = {
'cached_entries': row[1],
'total_hits': row[2] or 0
}
conn.close()
return {
'total_entries': total_entries,
'valid_entries': valid_entries,
'expired_entries': total_entries - valid_entries,
'total_cache_hits': total_hits,
'by_tool': tool_stats
}

3
pr/commands/__init__.py Normal file
View File

@ -0,0 +1,3 @@
from pr.commands.handlers import handle_command
__all__ = ['handle_command']

391
pr/commands/handlers.py Normal file
View File

@ -0,0 +1,391 @@
import json
from pr.ui import Colors
from pr.tools import read_file
from pr.tools.base import get_tools_definition
from pr.core.api import list_models
from pr.autonomous import run_autonomous_mode
def handle_command(assistant, command):
command_parts = command.strip().split(maxsplit=1)
cmd = command_parts[0].lower()
if cmd == '/auto':
if len(command_parts) < 2:
print(f"{Colors.RED}Usage: /auto [task description]{Colors.RESET}")
print(f"{Colors.GRAY}Example: /auto Create a Python web scraper for news sites{Colors.RESET}")
return True
task = command_parts[1]
run_autonomous_mode(assistant, task)
return True
if cmd in ['exit', 'quit', 'q']:
return False
elif cmd == 'help':
print(f"""
{Colors.BOLD}Available Commands:{Colors.RESET}
{Colors.BOLD}Basic:{Colors.RESET}
exit, quit, q - Exit the assistant
/help - Show this help message
/reset - Clear message history
/dump - Show message history as JSON
/verbose - Toggle verbose mode
/models - List available models
/tools - List available tools
{Colors.BOLD}File Operations:{Colors.RESET}
/review <file> - Review a file
/refactor <file> - Refactor code in a file
/obfuscate <file> - Obfuscate code in a file
{Colors.BOLD}Advanced Features:{Colors.RESET}
{Colors.CYAN}/auto <task>{Colors.RESET} - Enter autonomous mode
{Colors.CYAN}/workflow <name>{Colors.RESET} - Execute a workflow
{Colors.CYAN}/workflows{Colors.RESET} - List all workflows
{Colors.CYAN}/agent <role> <task>{Colors.RESET} - Create specialized agent and assign task
{Colors.CYAN}/agents{Colors.RESET} - Show active agents
{Colors.CYAN}/collaborate <task>{Colors.RESET} - Use multiple agents to collaborate
{Colors.CYAN}/knowledge <query>{Colors.RESET} - Search knowledge base
{Colors.CYAN}/remember <content>{Colors.RESET} - Store information in knowledge base
{Colors.CYAN}/history{Colors.RESET} - Show conversation history
{Colors.CYAN}/cache{Colors.RESET} - Show cache statistics
{Colors.CYAN}/cache clear{Colors.RESET} - Clear all caches
{Colors.CYAN}/stats{Colors.RESET} - Show system statistics
""")
elif cmd == '/reset':
assistant.messages = assistant.messages[:1]
print(f"{Colors.GREEN}Message history cleared{Colors.RESET}")
elif cmd == '/dump':
print(json.dumps(assistant.messages, indent=2))
elif cmd == '/verbose':
assistant.verbose = not assistant.verbose
print(f"Verbose mode: {Colors.GREEN if assistant.verbose else Colors.RED}{'ON' if assistant.verbose else 'OFF'}{Colors.RESET}")
elif cmd.startswith("/model"):
if len(command_parts) < 2:
print("Current model: " + Colors.GREEN + assistant.model + Colors.RESET)
else:
assistant.model = command_parts[1]
print(f"Model set to: {Colors.GREEN}{assistant.model}{Colors.RESET}")
elif cmd == '/models':
models = list_models(assistant.model_list_url, assistant.api_key)
if isinstance(models, dict) and 'error' in models:
print(f"{Colors.RED}Error fetching models: {models['error']}{Colors.RESET}")
else:
print(f"{Colors.BOLD}Available Models:{Colors.RESET}")
for model in models:
print(f"{Colors.CYAN}{model['id']}{Colors.RESET}")
elif cmd == '/tools':
print(f"{Colors.BOLD}Available Tools:{Colors.RESET}")
for tool in get_tools_definition():
func = tool['function']
print(f"{Colors.CYAN}{func['name']}{Colors.RESET}: {func['description']}")
elif cmd == '/review' and len(command_parts) > 1:
filename = command_parts[1]
review_file(assistant, filename)
elif cmd == '/refactor' and len(command_parts) > 1:
filename = command_parts[1]
refactor_file(assistant, filename)
elif cmd == '/obfuscate' and len(command_parts) > 1:
filename = command_parts[1]
obfuscate_file(assistant, filename)
elif cmd == '/workflows':
show_workflows(assistant)
elif cmd == '/workflow' and len(command_parts) > 1:
workflow_name = command_parts[1]
execute_workflow_command(assistant, workflow_name)
elif cmd == '/agent' and len(command_parts) > 1:
args = command_parts[1].split(maxsplit=1)
if len(args) < 2:
print(f"{Colors.RED}Usage: /agent <role> <task>{Colors.RESET}")
print(f"{Colors.GRAY}Available roles: coding, research, data_analysis, planning, testing, documentation{Colors.RESET}")
else:
role, task = args[0], args[1]
execute_agent_task(assistant, role, task)
elif cmd == '/agents':
show_agents(assistant)
elif cmd == '/collaborate' and len(command_parts) > 1:
task = command_parts[1]
collaborate_agents_command(assistant, task)
elif cmd == '/knowledge' and len(command_parts) > 1:
query = command_parts[1]
search_knowledge(assistant, query)
elif cmd == '/remember' and len(command_parts) > 1:
content = command_parts[1]
store_knowledge(assistant, content)
elif cmd == '/history':
show_conversation_history(assistant)
elif cmd == '/cache':
if len(command_parts) > 1 and command_parts[1].lower() == 'clear':
clear_caches(assistant)
else:
show_cache_stats(assistant)
elif cmd == '/stats':
show_system_stats(assistant)
else:
return None
return True
def review_file(assistant, filename):
result = read_file(filename)
if result['status'] == 'success':
message = f"Please review this file and provide feedback:\n\n{result['content']}"
from pr.core.assistant import process_message
process_message(assistant, message)
else:
print(f"{Colors.RED}Error reading file: {result['error']}{Colors.RESET}")
def refactor_file(assistant, filename):
result = read_file(filename)
if result['status'] == 'success':
message = f"Please refactor this code to improve its quality:\n\n{result['content']}"
from pr.core.assistant import process_message
process_message(assistant, message)
else:
print(f"{Colors.RED}Error reading file: {result['error']}{Colors.RESET}")
def obfuscate_file(assistant, filename):
result = read_file(filename)
if result['status'] == 'success':
message = f"Please obfuscate this code:\n\n{result['content']}"
from pr.core.assistant import process_message
process_message(assistant, message)
else:
print(f"{Colors.RED}Error reading file: {result['error']}{Colors.RESET}")
def show_workflows(assistant):
if not hasattr(assistant, 'enhanced'):
print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
return
workflows = assistant.enhanced.get_workflow_list()
if not workflows:
print(f"{Colors.YELLOW}No workflows found{Colors.RESET}")
return
print(f"\n{Colors.BOLD}Available Workflows:{Colors.RESET}")
for wf in workflows:
print(f"{Colors.CYAN}{wf['name']}{Colors.RESET}: {wf['description']}")
print(f" Executions: {wf['execution_count']}")
def execute_workflow_command(assistant, workflow_name):
if not hasattr(assistant, 'enhanced'):
print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
return
print(f"{Colors.YELLOW}Executing workflow: {workflow_name}...{Colors.RESET}")
result = assistant.enhanced.execute_workflow(workflow_name)
if 'error' in result:
print(f"{Colors.RED}Error: {result['error']}{Colors.RESET}")
else:
print(f"{Colors.GREEN}Workflow completed successfully{Colors.RESET}")
print(f"Execution ID: {result['execution_id']}")
print(f"Results: {json.dumps(result['results'], indent=2)}")
def execute_agent_task(assistant, role, task):
if not hasattr(assistant, 'enhanced'):
print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
return
print(f"{Colors.YELLOW}Creating {role} agent...{Colors.RESET}")
agent_id = assistant.enhanced.create_agent(role)
print(f"{Colors.GREEN}Agent created: {agent_id}{Colors.RESET}")
print(f"{Colors.YELLOW}Executing task...{Colors.RESET}")
result = assistant.enhanced.agent_task(agent_id, task)
if 'error' in result:
print(f"{Colors.RED}Error: {result['error']}{Colors.RESET}")
else:
print(f"\n{Colors.GREEN}{role.capitalize()} Agent Response:{Colors.RESET}")
print(result['response'])
def show_agents(assistant):
if not hasattr(assistant, 'enhanced'):
print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
return
summary = assistant.enhanced.get_agent_summary()
print(f"\n{Colors.BOLD}Agent Session Summary:{Colors.RESET}")
print(f"Active agents: {summary['active_agents']}")
if summary['agents']:
for agent in summary['agents']:
print(f"\n{Colors.CYAN}{agent['agent_id']}{Colors.RESET}")
print(f" Role: {agent['role']}")
print(f" Tasks completed: {agent['task_count']}")
print(f" Messages: {agent['message_count']}")
def collaborate_agents_command(assistant, task):
if not hasattr(assistant, 'enhanced'):
print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
return
print(f"{Colors.YELLOW}Initiating agent collaboration...{Colors.RESET}")
roles = ['coding', 'research', 'planning']
result = assistant.enhanced.collaborate_agents(task, roles)
print(f"\n{Colors.GREEN}Collaboration completed{Colors.RESET}")
print(f"\nOrchestrator response:")
if 'orchestrator' in result and 'response' in result['orchestrator']:
print(result['orchestrator']['response'])
if result.get('agents'):
print(f"\n{Colors.BOLD}Agent Results:{Colors.RESET}")
for agent_result in result['agents']:
if 'role' in agent_result:
print(f"\n{Colors.CYAN}{agent_result['role']}:{Colors.RESET}")
print(agent_result.get('response', 'No response'))
def search_knowledge(assistant, query):
if not hasattr(assistant, 'enhanced'):
print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
return
results = assistant.enhanced.search_knowledge(query)
if not results:
print(f"{Colors.YELLOW}No knowledge entries found for: {query}{Colors.RESET}")
return
print(f"\n{Colors.BOLD}Knowledge Search Results:{Colors.RESET}")
for entry in results:
print(f"\n{Colors.CYAN}[{entry.category}]{Colors.RESET}")
print(f" {entry.content[:200]}...")
print(f" Accessed: {entry.access_count} times")
def store_knowledge(assistant, content):
if not hasattr(assistant, 'enhanced'):
print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
return
import uuid
import time
from pr.memory import KnowledgeEntry
categories = assistant.enhanced.fact_extractor.categorize_content(content)
entry_id = str(uuid.uuid4())[:16]
entry = KnowledgeEntry(
entry_id=entry_id,
category=categories[0] if categories else 'general',
content=content,
metadata={'manual_entry': True},
created_at=time.time(),
updated_at=time.time()
)
assistant.enhanced.knowledge_store.add_entry(entry)
print(f"{Colors.GREEN}Knowledge stored successfully{Colors.RESET}")
print(f"Entry ID: {entry_id}")
print(f"Category: {entry.category}")
def show_conversation_history(assistant):
if not hasattr(assistant, 'enhanced'):
print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
return
history = assistant.enhanced.get_conversation_history(limit=10)
if not history:
print(f"{Colors.YELLOW}No conversation history found{Colors.RESET}")
return
print(f"\n{Colors.BOLD}Recent Conversations:{Colors.RESET}")
for conv in history:
import datetime
started = datetime.datetime.fromtimestamp(conv['started_at']).strftime('%Y-%m-%d %H:%M')
print(f"\n{Colors.CYAN}{conv['conversation_id']}{Colors.RESET}")
print(f" Started: {started}")
print(f" Messages: {conv['message_count']}")
if conv.get('summary'):
print(f" Summary: {conv['summary'][:100]}...")
if conv.get('topics'):
print(f" Topics: {', '.join(conv['topics'])}")
def show_cache_stats(assistant):
if not hasattr(assistant, 'enhanced'):
print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
return
stats = assistant.enhanced.get_cache_statistics()
print(f"\n{Colors.BOLD}Cache Statistics:{Colors.RESET}")
if 'api_cache' in stats:
api_stats = stats['api_cache']
print(f"\n{Colors.CYAN}API Cache:{Colors.RESET}")
print(f" Total entries: {api_stats['total_entries']}")
print(f" Valid entries: {api_stats['valid_entries']}")
print(f" Expired entries: {api_stats['expired_entries']}")
print(f" Cached tokens: {api_stats['total_cached_tokens']}")
if 'tool_cache' in stats:
tool_stats = stats['tool_cache']
print(f"\n{Colors.CYAN}Tool Cache:{Colors.RESET}")
print(f" Total entries: {tool_stats['total_entries']}")
print(f" Valid entries: {tool_stats['valid_entries']}")
print(f" Total cache hits: {tool_stats['total_cache_hits']}")
if tool_stats.get('by_tool'):
print(f"\n Per-tool statistics:")
for tool_name, tool_stat in tool_stats['by_tool'].items():
print(f" {tool_name}: {tool_stat['cached_entries']} entries, {tool_stat['total_hits']} hits")
def clear_caches(assistant):
if not hasattr(assistant, 'enhanced'):
print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
return
assistant.enhanced.clear_caches()
print(f"{Colors.GREEN}All caches cleared successfully{Colors.RESET}")
def show_system_stats(assistant):
if not hasattr(assistant, 'enhanced'):
print(f"{Colors.YELLOW}Enhanced features not initialized{Colors.RESET}")
return
print(f"\n{Colors.BOLD}System Statistics:{Colors.RESET}")
cache_stats = assistant.enhanced.get_cache_statistics()
knowledge_stats = assistant.enhanced.get_knowledge_statistics()
agent_summary = assistant.enhanced.get_agent_summary()
print(f"\n{Colors.CYAN}Knowledge Base:{Colors.RESET}")
print(f" Total entries: {knowledge_stats['total_entries']}")
print(f" Categories: {knowledge_stats['total_categories']}")
print(f" Total accesses: {knowledge_stats['total_accesses']}")
print(f" Vocabulary size: {knowledge_stats['vocabulary_size']}")
print(f"\n{Colors.CYAN}Active Agents:{Colors.RESET}")
print(f" Count: {agent_summary['active_agents']}")
if 'api_cache' in cache_stats:
print(f"\n{Colors.CYAN}Caching:{Colors.RESET}")
print(f" API cache entries: {cache_stats['api_cache']['valid_entries']}")
if 'tool_cache' in cache_stats:
print(f" Tool cache entries: {cache_stats['tool_cache']['valid_entries']}")

61
pr/config.py Normal file
View File

@ -0,0 +1,61 @@
import os
DEFAULT_MODEL = "x-ai/grok-code-fast-1"
DEFAULT_API_URL = "https://openrouter.ai/api/v1/chat/completions"
MODEL_LIST_URL = "https://openrouter.ai/api/v1/models"
DB_PATH = os.path.expanduser("~/.assistant_db.sqlite")
LOG_FILE = os.path.expanduser("~/.assistant_error.log")
CONTEXT_FILE = ".rcontext.txt"
GLOBAL_CONTEXT_FILE = os.path.expanduser("~/.rcontext.txt")
HISTORY_FILE = os.path.expanduser("~/.assistant_history")
DEFAULT_TEMPERATURE = 0.1
DEFAULT_MAX_TOKENS = 4096
MAX_AUTONOMOUS_ITERATIONS = 50
CONTEXT_COMPRESSION_THRESHOLD = 15
RECENT_MESSAGES_TO_KEEP = 20
API_TOTAL_TOKEN_LIMIT = 256000
MAX_OUTPUT_TOKENS = 30000
SAFETY_BUFFER_TOKENS = 30000
MAX_TOKENS_LIMIT = API_TOTAL_TOKEN_LIMIT - MAX_OUTPUT_TOKENS - SAFETY_BUFFER_TOKENS
CHARS_PER_TOKEN = 2.0
EMERGENCY_MESSAGES_TO_KEEP = 3
CONTENT_TRIM_LENGTH = 30000
MAX_TOOL_RESULT_LENGTH = 30000
LANGUAGE_KEYWORDS = {
'python': ['def', 'class', 'import', 'from', 'if', 'else', 'elif', 'for', 'while',
'return', 'try', 'except', 'finally', 'with', 'as', 'lambda', 'yield',
'None', 'True', 'False', 'and', 'or', 'not', 'in', 'is'],
'javascript': ['function', 'var', 'let', 'const', 'if', 'else', 'for', 'while',
'return', 'try', 'catch', 'finally', 'class', 'extends', 'new',
'this', 'null', 'undefined', 'true', 'false'],
'java': ['public', 'private', 'protected', 'class', 'interface', 'extends',
'implements', 'static', 'final', 'void', 'int', 'String', 'boolean',
'if', 'else', 'for', 'while', 'return', 'try', 'catch', 'finally'],
}
CACHE_ENABLED = True
API_CACHE_TTL = 3600
TOOL_CACHE_TTL = 300
WORKFLOW_MAX_RETRIES = 3
WORKFLOW_DEFAULT_TIMEOUT = 300
WORKFLOW_EXECUTOR_MAX_WORKERS = 5
AGENT_DEFAULT_TEMPERATURE = 0.7
AGENT_MAX_WORKERS = 3
AGENT_SESSION_TIMEOUT = 7200
KNOWLEDGE_IMPORTANCE_THRESHOLD = 0.5
KNOWLEDGE_SEARCH_LIMIT = 5
MEMORY_AUTO_SUMMARIZE = True
CONVERSATION_SUMMARY_THRESHOLD = 20
ADVANCED_CONTEXT_ENABLED = True
CONTEXT_RELEVANCE_THRESHOLD = 0.3
ADAPTIVE_CONTEXT_MIN = 10
ADAPTIVE_CONTEXT_MAX = 50

5
pr/core/__init__.py Normal file
View File

@ -0,0 +1,5 @@
from pr.core.assistant import Assistant
from pr.core.api import call_api, list_models
from pr.core.context import init_system_message, manage_context_window
__all__ = ['Assistant', 'call_api', 'list_models', 'init_system_message', 'manage_context_window']

View File

@ -0,0 +1,82 @@
import re
import math
from typing import List, Dict, Any
from collections import Counter
class AdvancedContextManager:
def __init__(self, knowledge_store=None, conversation_memory=None):
self.knowledge_store = knowledge_store
self.conversation_memory = conversation_memory
def adaptive_context_window(self, messages: List[Dict[str, Any]],
task_complexity: str = 'medium') -> int:
complexity_thresholds = {
'simple': 10,
'medium': 20,
'complex': 35,
'very_complex': 50
}
base_threshold = complexity_thresholds.get(task_complexity, 20)
message_complexity_score = self._analyze_message_complexity(messages)
if message_complexity_score > 0.7:
adjusted = int(base_threshold * 1.5)
elif message_complexity_score < 0.3:
adjusted = int(base_threshold * 0.7)
else:
adjusted = base_threshold
return max(base_threshold, adjusted)
def _analyze_message_complexity(self, messages: List[Dict[str, Any]]) -> float:
total_length = sum(len(msg.get('content', '')) for msg in messages)
avg_length = total_length / len(messages) if messages else 0
unique_words = set()
for msg in messages:
content = msg.get('content', '')
words = re.findall(r'\b\w+\b', content.lower())
unique_words.update(words)
vocabulary_richness = len(unique_words) / total_length if total_length > 0 else 0
# Simple complexity score based on length and richness
complexity = min(1.0, (avg_length / 100) + vocabulary_richness)
return complexity
def extract_key_sentences(self, text: str, top_k: int = 5) -> List[str]:
sentences = re.split(r'(?<=[.!?])\s+', text)
if not sentences:
return []
# Simple scoring based on length and position
scored_sentences = []
for i, sentence in enumerate(sentences):
length_score = min(1.0, len(sentence) / 50)
position_score = 1.0 if i == 0 else 0.8 if i < len(sentences) / 2 else 0.6
score = (length_score + position_score) / 2
scored_sentences.append((sentence, score))
scored_sentences.sort(key=lambda x: x[1], reverse=True)
return [s[0] for s in scored_sentences[:top_k]]
def advanced_summarize_messages(self, messages: List[Dict[str, Any]]) -> str:
all_content = ' '.join([msg.get('content', '') for msg in messages])
key_sentences = self.extract_key_sentences(all_content, top_k=3)
summary = ' '.join(key_sentences)
return summary if summary else "No content to summarize."
def score_message_relevance(self, message: Dict[str, Any], context: str) -> float:
content = message.get('content', '')
content_words = set(re.findall(r'\b\w+\b', content.lower()))
context_words = set(re.findall(r'\b\w+\b', context.lower()))
intersection = content_words & context_words
union = content_words | context_words
if not union:
return 0.0
return len(intersection) / len(union)

95
pr/core/api.py Normal file
View File

@ -0,0 +1,95 @@
import json
import urllib.request
import urllib.error
import logging
from pr.config import DEFAULT_TEMPERATURE, DEFAULT_MAX_TOKENS
from pr.core.context import auto_slim_messages
logger = logging.getLogger('pr')
def call_api(messages, model, api_url, api_key, use_tools, tools_definition, verbose=False):
try:
messages = auto_slim_messages(messages, verbose=verbose)
logger.debug(f"=== API CALL START ===")
logger.debug(f"Model: {model}")
logger.debug(f"API URL: {api_url}")
logger.debug(f"Use tools: {use_tools}")
logger.debug(f"Message count: {len(messages)}")
headers = {
'Content-Type': 'application/json',
}
if api_key:
headers['Authorization'] = f'Bearer {api_key}'
data = {
'model': model,
'messages': messages,
'temperature': DEFAULT_TEMPERATURE,
'max_tokens': DEFAULT_MAX_TOKENS
}
if "gpt-5" in model:
del data['temperature']
del data['max_tokens']
logger.debug("GPT-5 detected: removed temperature and max_tokens")
if use_tools:
data['tools'] = tools_definition
data['tool_choice'] = 'auto'
logger.debug(f"Tool calling enabled with {len(tools_definition)} tools")
request_json = json.dumps(data)
logger.debug(f"Request payload size: {len(request_json)} bytes")
req = urllib.request.Request(
api_url,
data=request_json.encode('utf-8'),
headers=headers,
method='POST'
)
logger.debug("Sending HTTP request...")
with urllib.request.urlopen(req) as response:
response_data = response.read().decode('utf-8')
logger.debug(f"Response received: {len(response_data)} bytes")
result = json.loads(response_data)
if 'usage' in result:
logger.debug(f"Token usage: {result['usage']}")
if 'choices' in result and result['choices']:
choice = result['choices'][0]
if 'message' in choice:
msg = choice['message']
logger.debug(f"Response role: {msg.get('role', 'N/A')}")
if 'content' in msg and msg['content']:
logger.debug(f"Response content length: {len(msg['content'])} chars")
if 'tool_calls' in msg:
logger.debug(f"Response contains {len(msg['tool_calls'])} tool call(s)")
logger.debug("=== API CALL END ===")
return result
except urllib.error.HTTPError as e:
error_body = e.read().decode('utf-8')
logger.error(f"API HTTP Error: {e.code} - {error_body}")
logger.debug("=== API CALL FAILED ===")
return {"error": f"API Error: {e.code}", "message": error_body}
except Exception as e:
logger.error(f"API call failed: {e}")
logger.debug("=== API CALL FAILED ===")
return {"error": str(e)}
def list_models(model_list_url, api_key):
try:
req = urllib.request.Request(model_list_url)
if api_key:
req.add_header('Authorization', f'Bearer {api_key}')
with urllib.request.urlopen(req) as response:
data = json.loads(response.read().decode('utf-8'))
return data.get('data', [])
except Exception as e:
return {"error": str(e)}

325
pr/core/assistant.py Normal file
View File

@ -0,0 +1,325 @@
import os
import sys
import json
import sqlite3
import signal
import logging
import traceback
import readline
import glob as glob_module
from concurrent.futures import ThreadPoolExecutor
from pr.config import DB_PATH, LOG_FILE, DEFAULT_MODEL, DEFAULT_API_URL, MODEL_LIST_URL, HISTORY_FILE
from pr.ui import Colors, render_markdown
from pr.core.context import init_system_message, truncate_tool_result
from pr.core.api import call_api
from pr.tools import (
http_fetch, run_command, run_command_interactive, read_file, write_file,
list_directory, mkdir, chdir, getpwd, db_set, db_get, db_query,
web_search, web_search_news, python_exec, index_source_directory,
open_editor, editor_insert_text, editor_replace_text, editor_search,
search_replace,close_editor,create_diff,apply_patch,
tail_process, kill_process
)
from pr.tools.patch import display_file_diff
from pr.tools.filesystem import display_edit_summary, display_edit_timeline, clear_edit_tracker
from pr.tools.base import get_tools_definition
from pr.commands import handle_command
logger = logging.getLogger('pr')
logger.setLevel(logging.DEBUG)
file_handler = logging.FileHandler(LOG_FILE)
file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
logger.addHandler(file_handler)
class Assistant:
def __init__(self, args):
self.args = args
self.messages = []
self.verbose = args.verbose
self.debug = getattr(args, 'debug', False)
self.syntax_highlighting = not args.no_syntax
if self.debug:
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.DEBUG)
console_handler.setFormatter(logging.Formatter('%(levelname)s: %(message)s'))
logger.addHandler(console_handler)
logger.debug("Debug mode enabled")
self.api_key = os.environ.get('OPENROUTER_API_KEY', '')
self.model = args.model or os.environ.get('AI_MODEL', DEFAULT_MODEL)
self.api_url = args.api_url or os.environ.get('API_URL', DEFAULT_API_URL)
self.model_list_url = args.model_list_url or os.environ.get('MODEL_LIST_URL', MODEL_LIST_URL)
self.use_tools = os.environ.get('USE_TOOLS', '1') == '1'
self.strict_mode = os.environ.get('STRICT_MODE', '0') == '1'
self.interrupt_count = 0
self.python_globals = {}
self.db_conn = None
self.autonomous_mode = False
self.autonomous_iterations = 0
self.init_database()
self.messages.append(init_system_message(args))
try:
from pr.core.enhanced_assistant import EnhancedAssistant
self.enhanced = EnhancedAssistant(self)
if self.debug:
logger.debug("Enhanced assistant features initialized")
except Exception as e:
logger.warning(f"Could not initialize enhanced features: {e}")
self.enhanced = None
def init_database(self):
try:
logger.debug(f"Initializing database at {DB_PATH}")
self.db_conn = sqlite3.connect(DB_PATH, check_same_thread=False)
cursor = self.db_conn.cursor()
cursor.execute('''CREATE TABLE IF NOT EXISTS kv_store
(key TEXT PRIMARY KEY, value TEXT, timestamp REAL)''')
cursor.execute('''CREATE TABLE IF NOT EXISTS file_versions
(id INTEGER PRIMARY KEY AUTOINCREMENT,
filepath TEXT, content TEXT, hash TEXT,
timestamp REAL, version INTEGER)''')
self.db_conn.commit()
logger.debug("Database initialized successfully")
except Exception as e:
logger.error(f"Database initialization error: {e}")
self.db_conn = None
def execute_tool_calls(self, tool_calls):
results = []
logger.debug(f"Executing {len(tool_calls)} tool call(s)")
with ThreadPoolExecutor(max_workers=5) as executor:
futures = []
for tool_call in tool_calls:
func_name = tool_call['function']['name']
arguments = json.loads(tool_call['function']['arguments'])
logger.debug(f"Tool call: {func_name} with arguments: {arguments}")
func_map = {
'http_fetch': lambda **kw: http_fetch(**kw),
'run_command': lambda **kw: run_command(**kw),
'tail_process': lambda **kw: tail_process(**kw),
'kill_process': lambda **kw: kill_process(**kw),
'run_command_interactive': lambda **kw: run_command_interactive(**kw),
'read_file': lambda **kw: read_file(**kw, db_conn=self.db_conn),
'write_file': lambda **kw: write_file(**kw, db_conn=self.db_conn),
'list_directory': lambda **kw: list_directory(**kw),
'mkdir': lambda **kw: mkdir(**kw),
'chdir': lambda **kw: chdir(**kw),
'getpwd': lambda **kw: getpwd(**kw),
'db_set': lambda **kw: db_set(**kw, db_conn=self.db_conn),
'db_get': lambda **kw: db_get(**kw, db_conn=self.db_conn),
'db_query': lambda **kw: db_query(**kw, db_conn=self.db_conn),
'web_search': lambda **kw: web_search(**kw),
'web_search_news': lambda **kw: web_search_news(**kw),
'python_exec': lambda **kw: python_exec(**kw, python_globals=self.python_globals),
'index_source_directory': lambda **kw: index_source_directory(**kw),
'search_replace': lambda **kw: search_replace(**kw, db_conn=self.db_conn),
'open_editor': lambda **kw: open_editor(**kw),
'editor_insert_text': lambda **kw: editor_insert_text(**kw, db_conn=self.db_conn),
'editor_replace_text': lambda **kw: editor_replace_text(**kw, db_conn=self.db_conn),
'editor_search': lambda **kw: editor_search(**kw),
'close_editor': lambda **kw: close_editor(**kw),
'create_diff': lambda **kw: create_diff(**kw),
'apply_patch': lambda **kw: apply_patch(**kw, db_conn=self.db_conn),
'display_file_diff': lambda **kw: display_file_diff(**kw),
'display_edit_summary': lambda **kw: display_edit_summary(),
'display_edit_timeline': lambda **kw: display_edit_timeline(**kw),
'clear_edit_tracker': lambda **kw: clear_edit_tracker(),
}
if func_name in func_map:
future = executor.submit(func_map[func_name], **arguments)
futures.append((tool_call['id'], future))
for tool_id, future in futures:
try:
result = future.result(timeout=30)
result = truncate_tool_result(result)
logger.debug(f"Tool result for {tool_id}: {str(result)[:200]}...")
results.append({
"tool_call_id": tool_id,
"role": "tool",
"content": json.dumps(result)
})
except Exception as e:
logger.debug(f"Tool error for {tool_id}: {str(e)}")
error_msg = str(e)[:200] if len(str(e)) > 200 else str(e)
results.append({
"tool_call_id": tool_id,
"role": "tool",
"content": json.dumps({"status": "error", "error": error_msg})
})
return results
def process_response(self, response):
if 'error' in response:
return f"Error: {response['error']}"
if 'choices' not in response or not response['choices']:
return "No response from API"
message = response['choices'][0]['message']
self.messages.append(message)
if 'tool_calls' in message and message['tool_calls']:
if self.verbose:
print(f"{Colors.YELLOW}Executing tool calls...{Colors.RESET}")
tool_results = self.execute_tool_calls(message['tool_calls'])
for result in tool_results:
self.messages.append(result)
follow_up = call_api(
self.messages, self.model, self.api_url, self.api_key,
self.use_tools, get_tools_definition(), verbose=self.verbose
)
return self.process_response(follow_up)
content = message.get('content', '')
return render_markdown(content, self.syntax_highlighting)
def signal_handler(self, signum, frame):
if self.autonomous_mode:
self.interrupt_count += 1
if self.interrupt_count >= 2:
print(f"\n{Colors.RED}Force exiting autonomous mode...{Colors.RESET}")
self.autonomous_mode = False
sys.exit(0)
else:
print(f"\n{Colors.YELLOW}Press Ctrl+C again to force exit{Colors.RESET}")
return
self.interrupt_count += 1
if self.interrupt_count >= 2:
print(f"\n{Colors.RED}Exiting...{Colors.RESET}")
self.cleanup()
sys.exit(0)
else:
print(f"\n{Colors.YELLOW}Press Ctrl+C again to exit{Colors.RESET}")
def setup_readline(self):
try:
readline.read_history_file(HISTORY_FILE)
except FileNotFoundError:
pass
readline.set_history_length(1000)
import atexit
atexit.register(readline.write_history_file, HISTORY_FILE)
commands = ['exit', 'quit', 'help', 'reset', 'dump', 'verbose',
'models', 'tools', 'review', 'refactor', 'obfuscate', '/auto']
def completer(text, state):
options = [cmd for cmd in commands if cmd.startswith(text)]
glob_pattern = os.path.expanduser(text) + '*'
path_options = glob_module.glob(glob_pattern)
path_options = [p + os.sep if os.path.isdir(p) else p for p in path_options]
combined_options = sorted(list(set(options + path_options)))
if state < len(combined_options):
return combined_options[state]
return None
delims = readline.get_completer_delims()
readline.set_completer_delims(delims.replace('/', ''))
readline.set_completer(completer)
readline.parse_and_bind('tab: complete')
def run_repl(self):
self.setup_readline()
signal.signal(signal.SIGINT, self.signal_handler)
print(f"{Colors.BOLD}r{Colors.RESET}")
print(f"Type 'help' for commands or start chatting")
while True:
try:
user_input = input(f"{Colors.BLUE}You>{Colors.RESET} ").strip()
if not user_input:
continue
cmd_result = handle_command(self, user_input)
if cmd_result is False:
break
elif cmd_result is True:
continue
process_message(self, user_input)
except EOFError:
break
except KeyboardInterrupt:
self.signal_handler(None, None)
except Exception as e:
print(f"{Colors.RED}Error: {e}{Colors.RESET}")
logging.error(f"REPL error: {e}\n{traceback.format_exc()}")
def run_single(self):
if self.args.message:
message = self.args.message
else:
message = sys.stdin.read()
process_message(self, message)
def cleanup(self):
if hasattr(self, 'enhanced') and self.enhanced:
try:
self.enhanced.cleanup()
except Exception as e:
logger.error(f"Error cleaning up enhanced features: {e}")
try:
from pr.multiplexer import cleanup_all_multiplexers
cleanup_all_multiplexers()
except Exception as e:
logger.error(f"Error cleaning up multiplexers: {e}")
if self.db_conn:
self.db_conn.close()
def run(self):
try:
if self.args.interactive or (not self.args.message and sys.stdin.isatty()):
self.run_repl()
else:
self.run_single()
finally:
self.cleanup()
def process_message(assistant, message):
assistant.messages.append({"role": "user", "content": message})
logger.debug(f"Processing user message: {message[:100]}...")
logger.debug(f"Current message count: {len(assistant.messages)}")
if assistant.verbose:
print(f"{Colors.GRAY}Sending request to API...{Colors.RESET}")
response = call_api(
assistant.messages, assistant.model, assistant.api_url,
assistant.api_key, assistant.use_tools, get_tools_definition(),
verbose=assistant.verbose
)
result = assistant.process_response(response)
print(f"\n{Colors.GREEN}r:{Colors.RESET} {result}\n")

108
pr/core/config_loader.py Normal file
View File

@ -0,0 +1,108 @@
import os
import configparser
from typing import Dict, Any
from pr.core.logging import get_logger
logger = get_logger('config')
CONFIG_FILE = os.path.expanduser("~/.prrc")
LOCAL_CONFIG_FILE = ".prrc"
def load_config() -> Dict[str, Any]:
config = {
'api': {},
'autonomous': {},
'ui': {},
'output': {},
'session': {}
}
global_config = _load_config_file(CONFIG_FILE)
local_config = _load_config_file(LOCAL_CONFIG_FILE)
for section in config.keys():
if section in global_config:
config[section].update(global_config[section])
if section in local_config:
config[section].update(local_config[section])
return config
def _load_config_file(filepath: str) -> Dict[str, Dict[str, Any]]:
if not os.path.exists(filepath):
return {}
try:
parser = configparser.ConfigParser()
parser.read(filepath)
config = {}
for section in parser.sections():
config[section] = {}
for key, value in parser.items(section):
config[section][key] = _parse_value(value)
logger.debug(f"Loaded configuration from {filepath}")
return config
except Exception as e:
logger.error(f"Error loading config from {filepath}: {e}")
return {}
def _parse_value(value: str) -> Any:
value = value.strip()
if value.lower() == 'true':
return True
if value.lower() == 'false':
return False
if value.isdigit():
return int(value)
try:
return float(value)
except ValueError:
pass
return value
def create_default_config(filepath: str = CONFIG_FILE):
default_config = """[api]
default_model = x-ai/grok-code-fast-1
timeout = 30
temperature = 0.7
max_tokens = 8096
[autonomous]
max_iterations = 50
context_threshold = 30
recent_messages_to_keep = 10
[ui]
syntax_highlighting = true
show_timestamps = false
color_output = true
[output]
format = text
verbose = false
quiet = false
[session]
auto_save = false
max_history = 1000
"""
try:
with open(filepath, 'w') as f:
f.write(default_config)
logger.info(f"Created default configuration at {filepath}")
return True
except Exception as e:
logger.error(f"Error creating config file: {e}")
return False

289
pr/core/context.py Normal file
View File

@ -0,0 +1,289 @@
import os
import json
import logging
from pr.config import (CONTEXT_FILE, GLOBAL_CONTEXT_FILE, CONTEXT_COMPRESSION_THRESHOLD,
RECENT_MESSAGES_TO_KEEP, MAX_TOKENS_LIMIT, CHARS_PER_TOKEN,
EMERGENCY_MESSAGES_TO_KEEP, CONTENT_TRIM_LENGTH, MAX_TOOL_RESULT_LENGTH)
from pr.ui import Colors
def truncate_tool_result(result, max_length=None):
if max_length is None:
max_length = MAX_TOOL_RESULT_LENGTH
if not isinstance(result, dict):
return result
result_copy = result.copy()
if "output" in result_copy and isinstance(result_copy["output"], str):
if len(result_copy["output"]) > max_length:
result_copy["output"] = result_copy["output"][:max_length] + f"\n... [truncated {len(result_copy['output']) - max_length} chars]"
if "content" in result_copy and isinstance(result_copy["content"], str):
if len(result_copy["content"]) > max_length:
result_copy["content"] = result_copy["content"][:max_length] + f"\n... [truncated {len(result_copy['content']) - max_length} chars]"
if "data" in result_copy and isinstance(result_copy["data"], str):
if len(result_copy["data"]) > max_length:
result_copy["data"] = result_copy["data"][:max_length] + f"\n... [truncated]"
if "error" in result_copy and isinstance(result_copy["error"], str):
if len(result_copy["error"]) > max_length // 2:
result_copy["error"] = result_copy["error"][:max_length // 2] + "... [truncated]"
return result_copy
def init_system_message(args):
context_parts = ["""You are a professional AI assistant with access to advanced tools.
File Operations:
- Use RPEditor tools (open_editor, editor_insert_text, editor_replace_text, editor_search, close_editor) for precise file modifications
- Always close editor files when finished
- Use write_file for complete file rewrites, search_replace for simple text replacements
Process Management:
- run_command executes shell commands with a timeout (default 30s)
- If a command times out, you receive a PID in the response
- Use tail_process(pid) to monitor running processes
- Use kill_process(pid) to terminate processes
- Manage long-running commands effectively using these tools
Shell Commands:
- Be a shell ninja using native OS tools
- Prefer standard Unix utilities over complex scripts
- Use run_command_interactive for commands requiring user input (vim, nano, etc.)"""]
#context_parts = ["You are a helpful AI assistant with access to advanced tools, including a powerful built-in editor (RPEditor). For file editing tasks, prefer using the editor-related tools like write_file, search_replace, open_editor, editor_insert_text, editor_replace_text, and editor_search, as they provide advanced editing capabilities with undo/redo, search, and precise text manipulation. The editor is integrated seamlessly and should be your primary tool for modifying files."]
max_context_size = 10000
if args.include_env:
env_context = "Environment Variables:\n"
for key, value in os.environ.items():
if not key.startswith('_'):
env_context += f"{key}={value}\n"
if len(env_context) > max_context_size:
env_context = env_context[:max_context_size] + "\n... [truncated]"
context_parts.append(env_context)
for context_file in [CONTEXT_FILE, GLOBAL_CONTEXT_FILE]:
if os.path.exists(context_file):
try:
with open(context_file, 'r') as f:
content = f.read()
if len(content) > max_context_size:
content = content[:max_context_size] + "\n... [truncated]"
context_parts.append(f"Context from {context_file}:\n{content}")
except Exception as e:
logging.error(f"Error reading context file {context_file}: {e}")
if args.context:
for ctx_file in args.context:
try:
with open(ctx_file, 'r') as f:
content = f.read()
if len(content) > max_context_size:
content = content[:max_context_size] + "\n... [truncated]"
context_parts.append(f"Context from {ctx_file}:\n{content}")
except Exception as e:
logging.error(f"Error reading context file {ctx_file}: {e}")
system_message = "\n\n".join(context_parts)
if len(system_message) > max_context_size * 3:
system_message = system_message[:max_context_size * 3] + "\n... [system message truncated]"
return {"role": "system", "content": system_message}
def should_compress_context(messages):
return len(messages) > CONTEXT_COMPRESSION_THRESHOLD
def compress_context(messages):
return manage_context_window(messages, verbose=False)
def manage_context_window(messages, verbose):
if len(messages) <= CONTEXT_COMPRESSION_THRESHOLD:
return messages
if verbose:
print(f"{Colors.YELLOW}📄 Managing context window (current: {len(messages)} messages)...{Colors.RESET}")
system_message = messages[0]
recent_messages = messages[-RECENT_MESSAGES_TO_KEEP:]
middle_messages = messages[1:-RECENT_MESSAGES_TO_KEEP]
if middle_messages:
summary = summarize_messages(middle_messages)
summary_message = {
"role": "system",
"content": f"[Previous conversation summary: {summary}]"
}
new_messages = [system_message, summary_message] + recent_messages
if verbose:
print(f"{Colors.GREEN}✓ Context compressed to {len(new_messages)} messages{Colors.RESET}")
return new_messages
return messages
def summarize_messages(messages):
summary_parts = []
for msg in messages:
role = msg.get("role", "unknown")
content = msg.get("content", "")
if role == "tool":
continue
if isinstance(content, str) and len(content) > 200:
content = content[:200] + "..."
summary_parts.append(f"{role}: {content}")
return " | ".join(summary_parts[:10])
def estimate_tokens(messages):
total_chars = 0
for msg in messages:
msg_json = json.dumps(msg)
total_chars += len(msg_json)
estimated_tokens = total_chars / CHARS_PER_TOKEN
overhead_multiplier = 1.3
return int(estimated_tokens * overhead_multiplier)
def trim_message_content(message, max_length):
trimmed_msg = message.copy()
if "content" in trimmed_msg:
content = trimmed_msg["content"]
if isinstance(content, str) and len(content) > max_length:
trimmed_msg["content"] = content[:max_length] + f"\n... [trimmed {len(content) - max_length} chars]"
elif isinstance(content, list):
trimmed_content = []
for item in content:
if isinstance(item, dict):
trimmed_item = item.copy()
if "text" in trimmed_item and len(trimmed_item["text"]) > max_length:
trimmed_item["text"] = trimmed_item["text"][:max_length] + f"\n... [trimmed]"
trimmed_content.append(trimmed_item)
else:
trimmed_content.append(item)
trimmed_msg["content"] = trimmed_content
if trimmed_msg.get("role") == "tool":
if "content" in trimmed_msg and isinstance(trimmed_msg["content"], str):
content = trimmed_msg["content"]
if len(content) > MAX_TOOL_RESULT_LENGTH:
trimmed_msg["content"] = content[:MAX_TOOL_RESULT_LENGTH] + f"\n... [trimmed {len(content) - MAX_TOOL_RESULT_LENGTH} chars]"
try:
parsed = json.loads(content)
if isinstance(parsed, dict):
if "output" in parsed and isinstance(parsed["output"], str) and len(parsed["output"]) > MAX_TOOL_RESULT_LENGTH // 2:
parsed["output"] = parsed["output"][:MAX_TOOL_RESULT_LENGTH // 2] + f"\n... [truncated]"
if "content" in parsed and isinstance(parsed["content"], str) and len(parsed["content"]) > MAX_TOOL_RESULT_LENGTH // 2:
parsed["content"] = parsed["content"][:MAX_TOOL_RESULT_LENGTH // 2] + f"\n... [truncated]"
trimmed_msg["content"] = json.dumps(parsed)
except:
pass
return trimmed_msg
def intelligently_trim_messages(messages, target_tokens, keep_recent=3):
if estimate_tokens(messages) <= target_tokens:
return messages
system_msg = messages[0] if messages and messages[0].get("role") == "system" else None
start_idx = 1 if system_msg else 0
recent_messages = messages[-keep_recent:] if len(messages) > keep_recent else messages[start_idx:]
middle_messages = messages[start_idx:-keep_recent] if len(messages) > keep_recent else []
trimmed_middle = []
for msg in middle_messages:
if msg.get("role") == "tool":
trimmed_middle.append(trim_message_content(msg, MAX_TOOL_RESULT_LENGTH // 2))
elif msg.get("role") in ["user", "assistant"]:
trimmed_middle.append(trim_message_content(msg, CONTENT_TRIM_LENGTH))
else:
trimmed_middle.append(msg)
result = ([system_msg] if system_msg else []) + trimmed_middle + recent_messages
if estimate_tokens(result) <= target_tokens:
return result
step_size = len(trimmed_middle) // 4 if len(trimmed_middle) >= 4 else 1
while len(trimmed_middle) > 0 and estimate_tokens(result) > target_tokens:
remove_count = min(step_size, len(trimmed_middle))
trimmed_middle = trimmed_middle[remove_count:]
result = ([system_msg] if system_msg else []) + trimmed_middle + recent_messages
if estimate_tokens(result) <= target_tokens:
return result
keep_recent -= 1
if keep_recent > 0:
return intelligently_trim_messages(messages, target_tokens, keep_recent)
return ([system_msg] if system_msg else []) + messages[-1:]
def auto_slim_messages(messages, verbose=False):
estimated_tokens = estimate_tokens(messages)
if estimated_tokens <= MAX_TOKENS_LIMIT:
return messages
if verbose:
print(f"{Colors.YELLOW}⚠️ Token limit approaching: ~{estimated_tokens} tokens (limit: {MAX_TOKENS_LIMIT}){Colors.RESET}")
print(f"{Colors.YELLOW}🔧 Intelligently trimming message content...{Colors.RESET}")
result = intelligently_trim_messages(messages, MAX_TOKENS_LIMIT, keep_recent=EMERGENCY_MESSAGES_TO_KEEP)
final_tokens = estimate_tokens(result)
if final_tokens > MAX_TOKENS_LIMIT:
if verbose:
print(f"{Colors.RED}⚠️ Still over limit after trimming, applying emergency reduction...{Colors.RESET}")
result = emergency_reduce_messages(result, MAX_TOKENS_LIMIT, verbose)
final_tokens = estimate_tokens(result)
if verbose:
removed_count = len(messages) - len(result)
print(f"{Colors.GREEN}✓ Optimized from {len(messages)} to {len(result)} messages{Colors.RESET}")
print(f"{Colors.GREEN} Token estimate: {estimated_tokens}{final_tokens} (~{estimated_tokens - final_tokens} saved){Colors.RESET}")
if removed_count > 0:
print(f"{Colors.GREEN} Removed {removed_count} older messages{Colors.RESET}")
return result
def emergency_reduce_messages(messages, target_tokens, verbose=False):
system_msg = messages[0] if messages and messages[0].get("role") == "system" else None
start_idx = 1 if system_msg else 0
keep_count = 2
while estimate_tokens(messages) > target_tokens and keep_count >= 1:
if len(messages[start_idx:]) <= keep_count:
break
result = ([system_msg] if system_msg else []) + messages[-keep_count:]
for i in range(len(result)):
result[i] = trim_message_content(result[i], CONTENT_TRIM_LENGTH // 2)
if estimate_tokens(result) <= target_tokens:
return result
keep_count -= 1
final = ([system_msg] if system_msg else []) + messages[-1:]
for i in range(len(final)):
if final[i].get("role") != "system":
final[i] = trim_message_content(final[i], 100)
return final

View File

@ -0,0 +1,278 @@
import logging
import json
import uuid
from typing import Optional, Dict, Any, List
from pr.config import (
DB_PATH, CACHE_ENABLED, API_CACHE_TTL, TOOL_CACHE_TTL,
WORKFLOW_EXECUTOR_MAX_WORKERS, AGENT_MAX_WORKERS,
KNOWLEDGE_SEARCH_LIMIT, ADVANCED_CONTEXT_ENABLED,
MEMORY_AUTO_SUMMARIZE, CONVERSATION_SUMMARY_THRESHOLD
)
from pr.cache import APICache, ToolCache
from pr.workflows import WorkflowEngine, WorkflowStorage
from pr.agents import AgentManager
from pr.memory import KnowledgeStore, ConversationMemory, FactExtractor
from pr.core.advanced_context import AdvancedContextManager
from pr.core.api import call_api
from pr.tools.base import get_tools_definition
logger = logging.getLogger('pr')
class EnhancedAssistant:
def __init__(self, base_assistant):
self.base = base_assistant
if CACHE_ENABLED:
self.api_cache = APICache(DB_PATH, API_CACHE_TTL)
self.tool_cache = ToolCache(DB_PATH, TOOL_CACHE_TTL)
else:
self.api_cache = None
self.tool_cache = None
self.workflow_storage = WorkflowStorage(DB_PATH)
self.workflow_engine = WorkflowEngine(
tool_executor=self._execute_tool_for_workflow,
max_workers=WORKFLOW_EXECUTOR_MAX_WORKERS
)
self.agent_manager = AgentManager(DB_PATH, self._api_caller_for_agent)
self.knowledge_store = KnowledgeStore(DB_PATH)
self.conversation_memory = ConversationMemory(DB_PATH)
self.fact_extractor = FactExtractor()
if ADVANCED_CONTEXT_ENABLED:
self.context_manager = AdvancedContextManager(
knowledge_store=self.knowledge_store,
conversation_memory=self.conversation_memory
)
else:
self.context_manager = None
self.current_conversation_id = str(uuid.uuid4())[:16]
self.conversation_memory.create_conversation(
self.current_conversation_id,
session_id=str(uuid.uuid4())[:16]
)
logger.info("Enhanced Assistant initialized with all features")
def _execute_tool_for_workflow(self, tool_name: str, arguments: Dict[str, Any]) -> Any:
if self.tool_cache:
cached_result = self.tool_cache.get(tool_name, arguments)
if cached_result is not None:
logger.debug(f"Tool cache hit for {tool_name}")
return cached_result
func_map = {
'read_file': lambda **kw: self.base.execute_tool_calls([{
'id': 'temp',
'function': {'name': 'read_file', 'arguments': json.dumps(kw)}
}])[0],
'write_file': lambda **kw: self.base.execute_tool_calls([{
'id': 'temp',
'function': {'name': 'write_file', 'arguments': json.dumps(kw)}
}])[0],
'list_directory': lambda **kw: self.base.execute_tool_calls([{
'id': 'temp',
'function': {'name': 'list_directory', 'arguments': json.dumps(kw)}
}])[0],
'run_command': lambda **kw: self.base.execute_tool_calls([{
'id': 'temp',
'function': {'name': 'run_command', 'arguments': json.dumps(kw)}
}])[0],
}
if tool_name in func_map:
result = func_map[tool_name](**arguments)
if self.tool_cache:
content = result.get('content', '')
try:
parsed_content = json.loads(content) if isinstance(content, str) else content
self.tool_cache.set(tool_name, arguments, parsed_content)
except Exception:
pass
return result
return {'error': f'Unknown tool: {tool_name}'}
def _api_caller_for_agent(self, messages: List[Dict[str, Any]],
temperature: float, max_tokens: int) -> Dict[str, Any]:
return call_api(
messages,
self.base.model,
self.base.api_url,
self.base.api_key,
use_tools=False,
tools=None,
temperature=temperature,
max_tokens=max_tokens,
verbose=self.base.verbose
)
def enhanced_call_api(self, messages: List[Dict[str, Any]]) -> Dict[str, Any]:
if self.api_cache and CACHE_ENABLED:
cached_response = self.api_cache.get(
self.base.model, messages,
0.7, 4096
)
if cached_response:
logger.debug("API cache hit")
return cached_response
response = call_api(
messages,
self.base.model,
self.base.api_url,
self.base.api_key,
self.base.use_tools,
get_tools_definition(),
verbose=self.base.verbose
)
if self.api_cache and CACHE_ENABLED and 'error' not in response:
token_count = response.get('usage', {}).get('total_tokens', 0)
self.api_cache.set(
self.base.model, messages,
0.7, 4096,
response, token_count
)
return response
def process_with_enhanced_context(self, user_message: str) -> str:
self.base.messages.append({"role": "user", "content": user_message})
self.conversation_memory.add_message(
self.current_conversation_id,
str(uuid.uuid4())[:16],
'user',
user_message
)
if MEMORY_AUTO_SUMMARIZE and len(self.base.messages) % 5 == 0:
facts = self.fact_extractor.extract_facts(user_message)
for fact in facts[:3]:
entry_id = str(uuid.uuid4())[:16]
from pr.memory import KnowledgeEntry
import time
categories = self.fact_extractor.categorize_content(fact['text'])
entry = KnowledgeEntry(
entry_id=entry_id,
category=categories[0] if categories else 'general',
content=fact['text'],
metadata={'type': fact['type'], 'confidence': fact['confidence']},
created_at=time.time(),
updated_at=time.time()
)
self.knowledge_store.add_entry(entry)
if self.context_manager and ADVANCED_CONTEXT_ENABLED:
enhanced_messages, context_info = self.context_manager.create_enhanced_context(
self.base.messages,
user_message,
include_knowledge=True
)
if self.base.verbose:
logger.info(f"Enhanced context: {context_info}")
working_messages = enhanced_messages
else:
working_messages = self.base.messages
response = self.enhanced_call_api(working_messages)
result = self.base.process_response(response)
if len(self.base.messages) >= CONVERSATION_SUMMARY_THRESHOLD:
summary = self.context_manager.advanced_summarize_messages(
self.base.messages[-CONVERSATION_SUMMARY_THRESHOLD:]
) if self.context_manager else "Conversation in progress"
topics = self.fact_extractor.categorize_content(summary)
self.conversation_memory.update_conversation_summary(
self.current_conversation_id,
summary,
topics
)
return result
def execute_workflow(self, workflow_name: str,
initial_variables: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
workflow = self.workflow_storage.load_workflow_by_name(workflow_name)
if not workflow:
return {'error': f'Workflow "{workflow_name}" not found'}
context = self.workflow_engine.execute_workflow(workflow, initial_variables)
execution_id = self.workflow_storage.save_execution(
self.workflow_storage.load_workflow_by_name(workflow_name).name,
context
)
return {
'success': True,
'execution_id': execution_id,
'results': context.step_results,
'execution_log': context.execution_log
}
def create_agent(self, role_name: str, agent_id: Optional[str] = None) -> str:
return self.agent_manager.create_agent(role_name, agent_id)
def agent_task(self, agent_id: str, task: str) -> Dict[str, Any]:
return self.agent_manager.execute_agent_task(agent_id, task)
def collaborate_agents(self, task: str, agent_roles: List[str]) -> Dict[str, Any]:
orchestrator_id = self.agent_manager.create_agent('orchestrator')
return self.agent_manager.collaborate_agents(orchestrator_id, task, agent_roles)
def search_knowledge(self, query: str, limit: int = KNOWLEDGE_SEARCH_LIMIT) -> List[Any]:
return self.knowledge_store.search_entries(query, top_k=limit)
def get_cache_statistics(self) -> Dict[str, Any]:
stats = {}
if self.api_cache:
stats['api_cache'] = self.api_cache.get_statistics()
if self.tool_cache:
stats['tool_cache'] = self.tool_cache.get_statistics()
return stats
def get_workflow_list(self) -> List[Dict[str, Any]]:
return self.workflow_storage.list_workflows()
def get_agent_summary(self) -> Dict[str, Any]:
return self.agent_manager.get_session_summary()
def get_knowledge_statistics(self) -> Dict[str, Any]:
return self.knowledge_store.get_statistics()
def get_conversation_history(self, limit: int = 10) -> List[Dict[str, Any]]:
return self.conversation_memory.get_recent_conversations(limit=limit)
def clear_caches(self):
if self.api_cache:
self.api_cache.clear_all()
if self.tool_cache:
self.tool_cache.clear_all()
logger.info("All caches cleared")
def cleanup(self):
if self.api_cache:
self.api_cache.clear_expired()
if self.tool_cache:
self.tool_cache.clear_expired()
self.agent_manager.clear_session()

44
pr/core/exceptions.py Normal file
View File

@ -0,0 +1,44 @@
class PRException(Exception):
pass
class APIException(PRException):
pass
class APIConnectionError(APIException):
pass
class APITimeoutError(APIException):
pass
class APIResponseError(APIException):
pass
class ConfigurationError(PRException):
pass
class ToolExecutionError(PRException):
def __init__(self, tool_name: str, message: str):
self.tool_name = tool_name
super().__init__(f"Error executing tool '{tool_name}': {message}")
class FileSystemError(PRException):
pass
class SessionError(PRException):
pass
class ContextError(PRException):
pass
class ValidationError(PRException):
pass

46
pr/core/logging.py Normal file
View File

@ -0,0 +1,46 @@
import logging
import os
from logging.handlers import RotatingFileHandler
from pr.config import LOG_FILE
def setup_logging(verbose=False):
log_dir = os.path.dirname(LOG_FILE)
if log_dir and not os.path.exists(log_dir):
os.makedirs(log_dir, exist_ok=True)
logger = logging.getLogger('pr')
logger.setLevel(logging.DEBUG if verbose else logging.INFO)
if logger.handlers:
logger.handlers.clear()
file_handler = RotatingFileHandler(
LOG_FILE,
maxBytes=10 * 1024 * 1024,
backupCount=5
)
file_handler.setLevel(logging.DEBUG)
file_formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
file_handler.setFormatter(file_formatter)
logger.addHandler(file_handler)
if verbose:
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_formatter = logging.Formatter(
'%(levelname)s: %(message)s'
)
console_handler.setFormatter(console_formatter)
logger.addHandler(console_handler)
return logger
def get_logger(name=None):
if name:
return logging.getLogger(f'pr.{name}')
return logging.getLogger('pr')

146
pr/core/session.py Normal file
View File

@ -0,0 +1,146 @@
import json
import os
from datetime import datetime
from typing import Dict, List, Optional
from pr.core.logging import get_logger
logger = get_logger('session')
SESSIONS_DIR = os.path.expanduser("~/.assistant_sessions")
class SessionManager:
def __init__(self):
os.makedirs(SESSIONS_DIR, exist_ok=True)
def save_session(self, name: str, messages: List[Dict], metadata: Optional[Dict] = None) -> bool:
try:
session_file = os.path.join(SESSIONS_DIR, f"{name}.json")
session_data = {
'name': name,
'created_at': datetime.now().isoformat(),
'messages': messages,
'metadata': metadata or {}
}
with open(session_file, 'w') as f:
json.dump(session_data, f, indent=2)
logger.info(f"Session saved: {name}")
return True
except Exception as e:
logger.error(f"Error saving session {name}: {e}")
return False
def load_session(self, name: str) -> Optional[Dict]:
try:
session_file = os.path.join(SESSIONS_DIR, f"{name}.json")
if not os.path.exists(session_file):
logger.warning(f"Session not found: {name}")
return None
with open(session_file, 'r') as f:
session_data = json.load(f)
logger.info(f"Session loaded: {name}")
return session_data
except Exception as e:
logger.error(f"Error loading session {name}: {e}")
return None
def list_sessions(self) -> List[Dict]:
sessions = []
try:
for filename in os.listdir(SESSIONS_DIR):
if filename.endswith('.json'):
filepath = os.path.join(SESSIONS_DIR, filename)
try:
with open(filepath, 'r') as f:
data = json.load(f)
sessions.append({
'name': data.get('name', filename[:-5]),
'created_at': data.get('created_at', 'unknown'),
'message_count': len(data.get('messages', [])),
'metadata': data.get('metadata', {})
})
except Exception as e:
logger.warning(f"Error reading session file {filename}: {e}")
sessions.sort(key=lambda x: x['created_at'], reverse=True)
except Exception as e:
logger.error(f"Error listing sessions: {e}")
return sessions
def delete_session(self, name: str) -> bool:
try:
session_file = os.path.join(SESSIONS_DIR, f"{name}.json")
if not os.path.exists(session_file):
logger.warning(f"Session not found: {name}")
return False
os.remove(session_file)
logger.info(f"Session deleted: {name}")
return True
except Exception as e:
logger.error(f"Error deleting session {name}: {e}")
return False
def export_session(self, name: str, output_path: str, format: str = 'json') -> bool:
session_data = self.load_session(name)
if not session_data:
return False
try:
if format == 'json':
with open(output_path, 'w') as f:
json.dump(session_data, f, indent=2)
elif format == 'markdown':
with open(output_path, 'w') as f:
f.write(f"# Session: {name}\n\n")
f.write(f"Created: {session_data['created_at']}\n\n")
f.write("---\n\n")
for msg in session_data['messages']:
role = msg.get('role', 'unknown')
content = msg.get('content', '')
f.write(f"## {role.capitalize()}\n\n")
f.write(f"{content}\n\n")
f.write("---\n\n")
elif format == 'txt':
with open(output_path, 'w') as f:
f.write(f"Session: {name}\n")
f.write(f"Created: {session_data['created_at']}\n")
f.write("=" * 80 + "\n\n")
for msg in session_data['messages']:
role = msg.get('role', 'unknown')
content = msg.get('content', '')
f.write(f"[{role.upper()}]\n")
f.write(f"{content}\n")
f.write("-" * 80 + "\n\n")
else:
logger.error(f"Unsupported export format: {format}")
return False
logger.info(f"Session exported to {output_path}")
return True
except Exception as e:
logger.error(f"Error exporting session: {e}")
return False

162
pr/core/usage_tracker.py Normal file
View File

@ -0,0 +1,162 @@
import json
import os
from datetime import datetime
from typing import Dict, Optional
from pr.core.logging import get_logger
logger = get_logger('usage')
USAGE_DB_FILE = os.path.expanduser("~/.assistant_usage.json")
MODEL_COSTS = {
'x-ai/grok-code-fast-1': {'input': 0.0, 'output': 0.0},
'gpt-4': {'input': 0.03, 'output': 0.06},
'gpt-4-turbo': {'input': 0.01, 'output': 0.03},
'gpt-3.5-turbo': {'input': 0.0005, 'output': 0.0015},
'claude-3-opus': {'input': 0.015, 'output': 0.075},
'claude-3-sonnet': {'input': 0.003, 'output': 0.015},
'claude-3-haiku': {'input': 0.00025, 'output': 0.00125},
}
class UsageTracker:
def __init__(self):
self.session_usage = {
'requests': 0,
'total_tokens': 0,
'input_tokens': 0,
'output_tokens': 0,
'estimated_cost': 0.0,
'models_used': {}
}
def track_request(
self,
model: str,
input_tokens: int,
output_tokens: int,
total_tokens: Optional[int] = None
):
if total_tokens is None:
total_tokens = input_tokens + output_tokens
self.session_usage['requests'] += 1
self.session_usage['total_tokens'] += total_tokens
self.session_usage['input_tokens'] += input_tokens
self.session_usage['output_tokens'] += output_tokens
if model not in self.session_usage['models_used']:
self.session_usage['models_used'][model] = {
'requests': 0,
'tokens': 0,
'cost': 0.0
}
model_usage = self.session_usage['models_used'][model]
model_usage['requests'] += 1
model_usage['tokens'] += total_tokens
cost = self._calculate_cost(model, input_tokens, output_tokens)
model_usage['cost'] += cost
self.session_usage['estimated_cost'] += cost
self._save_to_history(model, input_tokens, output_tokens, cost)
logger.debug(
f"Tracked request: {model}, tokens: {total_tokens}, cost: ${cost:.4f}"
)
def _calculate_cost(self, model: str, input_tokens: int, output_tokens: int) -> float:
if model not in MODEL_COSTS:
base_model = model.split('/')[0] if '/' in model else model
if base_model not in MODEL_COSTS:
logger.warning(f"Unknown model for cost calculation: {model}")
return 0.0
costs = MODEL_COSTS[base_model]
else:
costs = MODEL_COSTS[model]
input_cost = (input_tokens / 1000) * costs['input']
output_cost = (output_tokens / 1000) * costs['output']
return input_cost + output_cost
def _save_to_history(self, model: str, input_tokens: int, output_tokens: int, cost: float):
try:
history = []
if os.path.exists(USAGE_DB_FILE):
with open(USAGE_DB_FILE, 'r') as f:
history = json.load(f)
history.append({
'timestamp': datetime.now().isoformat(),
'model': model,
'input_tokens': input_tokens,
'output_tokens': output_tokens,
'total_tokens': input_tokens + output_tokens,
'cost': cost
})
if len(history) > 10000:
history = history[-10000:]
with open(USAGE_DB_FILE, 'w') as f:
json.dump(history, f, indent=2)
except Exception as e:
logger.error(f"Error saving usage history: {e}")
def get_session_summary(self) -> Dict:
return self.session_usage.copy()
def get_formatted_summary(self) -> str:
usage = self.session_usage
lines = [
"\n=== Session Usage Summary ===",
f"Total Requests: {usage['requests']}",
f"Total Tokens: {usage['total_tokens']:,}",
f" Input: {usage['input_tokens']:,}",
f" Output: {usage['output_tokens']:,}",
f"Estimated Cost: ${usage['estimated_cost']:.4f}",
]
if usage['models_used']:
lines.append("\nModels Used:")
for model, stats in usage['models_used'].items():
lines.append(
f" {model}: {stats['requests']} requests, "
f"{stats['tokens']:,} tokens, ${stats['cost']:.4f}"
)
return '\n'.join(lines)
@staticmethod
def get_total_usage() -> Dict:
if not os.path.exists(USAGE_DB_FILE):
return {
'total_requests': 0,
'total_tokens': 0,
'total_cost': 0.0
}
try:
with open(USAGE_DB_FILE, 'r') as f:
history = json.load(f)
total_tokens = sum(entry['total_tokens'] for entry in history)
total_cost = sum(entry['cost'] for entry in history)
return {
'total_requests': len(history),
'total_tokens': total_tokens,
'total_cost': total_cost
}
except Exception as e:
logger.error(f"Error loading usage history: {e}")
return {
'total_requests': 0,
'total_tokens': 0,
'total_cost': 0.0
}

86
pr/core/validation.py Normal file
View File

@ -0,0 +1,86 @@
import os
from typing import Optional
from pr.core.exceptions import ValidationError
def validate_file_path(path: str, must_exist: bool = False) -> str:
if not path:
raise ValidationError("File path cannot be empty")
if must_exist and not os.path.exists(path):
raise ValidationError(f"File does not exist: {path}")
if must_exist and os.path.isdir(path):
raise ValidationError(f"Path is a directory, not a file: {path}")
return os.path.abspath(path)
def validate_directory_path(path: str, must_exist: bool = False, create: bool = False) -> str:
if not path:
raise ValidationError("Directory path cannot be empty")
abs_path = os.path.abspath(path)
if must_exist and not os.path.exists(abs_path):
if create:
os.makedirs(abs_path, exist_ok=True)
else:
raise ValidationError(f"Directory does not exist: {abs_path}")
if os.path.exists(abs_path) and not os.path.isdir(abs_path):
raise ValidationError(f"Path is not a directory: {abs_path}")
return abs_path
def validate_model_name(model: str) -> str:
if not model:
raise ValidationError("Model name cannot be empty")
if len(model) < 2:
raise ValidationError("Model name too short")
return model
def validate_api_url(url: str) -> str:
if not url:
raise ValidationError("API URL cannot be empty")
if not url.startswith(('http://', 'https://')):
raise ValidationError("API URL must start with http:// or https://")
return url
def validate_session_name(name: str) -> str:
if not name:
raise ValidationError("Session name cannot be empty")
invalid_chars = ['/', '\\', ':', '*', '?', '"', '<', '>', '|']
for char in invalid_chars:
if char in name:
raise ValidationError(f"Session name contains invalid character: {char}")
if len(name) > 255:
raise ValidationError("Session name too long (max 255 characters)")
return name
def validate_temperature(temp: float) -> float:
if not 0.0 <= temp <= 2.0:
raise ValidationError("Temperature must be between 0.0 and 2.0")
return temp
def validate_max_tokens(tokens: int) -> int:
if tokens < 1:
raise ValidationError("Max tokens must be at least 1")
if tokens > 100000:
raise ValidationError("Max tokens too high (max 100000)")
return tokens

994
pr/editor.py Normal file
View File

@ -0,0 +1,994 @@
#!/usr/bin/env python3
import curses
import threading
import sys
import os
import re
import socket
import pickle
import queue
import time
import atexit
import signal
import traceback
from contextlib import contextmanager
class RPEditor:
def __init__(self, filename=None, auto_save=False, timeout=30):
"""
Initialize RPEditor with enhanced robustness features.
Args:
filename: File to edit
auto_save: Enable auto-save on exit
timeout: Command timeout in seconds
"""
self.filename = filename
self.lines = [""]
self.cursor_y = 0
self.cursor_x = 0
self.mode = 'normal'
self.command = ""
self.stdscr = None
self.running = False
self.thread = None
self.socket_thread = None
self.prev_key = None
self.clipboard = ""
self.undo_stack = []
self.redo_stack = []
self.selection_start = None
self.selection_end = None
self.max_undo = 100
self.lock = threading.RLock()
self.command_queue = queue.Queue()
self.auto_save = auto_save
self.timeout = timeout
self._cleanup_registered = False
self._original_terminal_state = None
self._exception_occurred = False
# Create socket pair with error handling
try:
self.client_sock, self.server_sock = socket.socketpair()
self.client_sock.settimeout(self.timeout)
self.server_sock.settimeout(self.timeout)
except Exception as e:
self._cleanup()
raise RuntimeError(f"Failed to create socket pair: {e}")
# Register cleanup handlers
self._register_cleanup()
if filename:
self.load_file()
def _register_cleanup(self):
"""Register cleanup handlers for proper shutdown."""
if not self._cleanup_registered:
atexit.register(self._cleanup)
signal.signal(signal.SIGINT, self._signal_handler)
signal.signal(signal.SIGTERM, self._signal_handler)
self._cleanup_registered = True
def _signal_handler(self, signum, frame):
"""Handle signals for clean shutdown."""
self._cleanup()
sys.exit(0)
def _cleanup(self):
"""Comprehensive cleanup of all resources."""
try:
# Stop the editor
self.running = False
# Save if auto-save is enabled
if self.auto_save and self.filename and not self._exception_occurred:
try:
self._save_file()
except:
pass
# Clean up curses
if self.stdscr:
try:
self.stdscr.keypad(False)
curses.nocbreak()
curses.echo()
curses.curs_set(1)
except:
pass
finally:
try:
curses.endwin()
except:
pass
# Clear screen after curses cleanup
try:
os.system('clear' if os.name != 'nt' else 'cls')
except:
pass
# Close sockets
for sock in [self.client_sock, self.server_sock]:
if sock:
try:
sock.close()
except:
pass
# Wait for threads to finish
for thread in [self.thread, self.socket_thread]:
if thread and thread.is_alive():
thread.join(timeout=1)
except:
pass
def load_file(self):
"""Load file with enhanced error handling."""
try:
if os.path.exists(self.filename):
with open(self.filename, 'r', encoding='utf-8', errors='replace') as f:
content = f.read()
self.lines = content.splitlines() if content else [""]
else:
self.lines = [""]
except Exception as e:
self.lines = [""]
# Don't raise, just use empty content
def _save_file(self):
"""Save file with enhanced error handling and backup."""
with self.lock:
if not self.filename:
return False
try:
# Create backup if file exists
if os.path.exists(self.filename):
backup_name = f"{self.filename}.bak"
try:
with open(self.filename, 'r', encoding='utf-8') as f:
backup_content = f.read()
with open(backup_name, 'w', encoding='utf-8') as f:
f.write(backup_content)
except:
pass # Backup failed, but continue with save
# Save the file
with open(self.filename, 'w', encoding='utf-8') as f:
f.write('\n'.join(self.lines))
return True
except Exception as e:
return False
def save_file(self):
"""Thread-safe save file command."""
try:
self.client_sock.send(pickle.dumps({'command': 'save_file'}))
except:
return self._save_file() # Fallback to direct save
def start(self):
"""Start the editor with enhanced error handling."""
if self.running:
return False
try:
self.running = True
self.socket_thread = threading.Thread(target=self.socket_listener, daemon=True)
self.socket_thread.start()
self.thread = threading.Thread(target=self.run, daemon=True)
self.thread.start()
return True
except Exception as e:
self.running = False
self._cleanup()
raise RuntimeError(f"Failed to start editor: {e}")
def stop(self):
"""Stop the editor with proper cleanup."""
try:
if self.client_sock:
self.client_sock.send(pickle.dumps({'command': 'stop'}))
except:
pass
self.running = False
time.sleep(0.1) # Give threads time to finish
self._cleanup()
def run(self):
"""Run the main editor loop with exception handling."""
try:
curses.wrapper(self.main_loop)
except Exception as e:
self._exception_occurred = True
self._cleanup()
def main_loop(self, stdscr):
"""Main editor loop with enhanced error recovery."""
self.stdscr = stdscr
try:
# Configure curses
curses.curs_set(1)
self.stdscr.keypad(True)
self.stdscr.timeout(100) # Non-blocking with timeout
while self.running:
try:
# Process queued commands
while True:
try:
command = self.command_queue.get_nowait()
with self.lock:
self.execute_command(command)
except queue.Empty:
break
# Draw screen
with self.lock:
self.draw()
# Handle input
try:
key = self.stdscr.getch()
if key != -1: # -1 means timeout/no input
with self.lock:
self.handle_key(key)
except curses.error:
pass # Ignore curses errors
except Exception as e:
# Log error but continue running
pass
except Exception as e:
self._exception_occurred = True
finally:
self._cleanup()
def draw(self):
"""Draw the editor screen with error handling."""
try:
self.stdscr.clear()
height, width = self.stdscr.getmaxyx()
# Draw lines
for i, line in enumerate(self.lines):
if i >= height - 1:
break
try:
# Handle long lines and special characters
display_line = line[:width-1] if len(line) >= width else line
self.stdscr.addstr(i, 0, display_line)
except curses.error:
pass # Skip lines that can't be displayed
# Draw status line
status = f"{self.mode.upper()} | {self.filename or 'untitled'} | {self.cursor_y+1}:{self.cursor_x+1}"
if self.mode == 'command':
status = self.command[:width-1]
try:
self.stdscr.addstr(height - 1, 0, status[:width-1])
except curses.error:
pass
# Position cursor
cursor_x = min(self.cursor_x, width - 1)
cursor_y = min(self.cursor_y, height - 2)
try:
self.stdscr.move(cursor_y, cursor_x)
except curses.error:
pass
self.stdscr.refresh()
except Exception:
pass # Continue even if draw fails
def handle_key(self, key):
"""Handle keyboard input with error recovery."""
try:
if self.mode == 'normal':
self.handle_normal(key)
elif self.mode == 'insert':
self.handle_insert(key)
elif self.mode == 'command':
self.handle_command(key)
except Exception:
pass # Continue on error
def handle_normal(self, key):
"""Handle normal mode keys."""
try:
if key == ord('h') or key == curses.KEY_LEFT:
self.move_cursor(0, -1)
elif key == ord('j') or key == curses.KEY_DOWN:
self.move_cursor(1, 0)
elif key == ord('k') or key == curses.KEY_UP:
self.move_cursor(-1, 0)
elif key == ord('l') or key == curses.KEY_RIGHT:
self.move_cursor(0, 1)
elif key == ord('i'):
self.mode = 'insert'
elif key == ord(':'):
self.mode = 'command'
self.command = ":"
elif key == ord('x'):
self._delete_char()
elif key == ord('a'):
self.cursor_x = min(self.cursor_x + 1, len(self.lines[self.cursor_y]))
self.mode = 'insert'
elif key == ord('A'):
self.cursor_x = len(self.lines[self.cursor_y])
self.mode = 'insert'
elif key == ord('o'):
self._insert_line(self.cursor_y + 1, "")
self.cursor_y += 1
self.cursor_x = 0
self.mode = 'insert'
elif key == ord('O'):
self._insert_line(self.cursor_y, "")
self.cursor_x = 0
self.mode = 'insert'
elif key == ord('d') and self.prev_key == ord('d'):
if self.cursor_y < len(self.lines):
self.clipboard = self.lines[self.cursor_y]
self._delete_line(self.cursor_y)
if self.cursor_y >= len(self.lines):
self.cursor_y = max(0, len(self.lines) - 1)
self.cursor_x = 0
elif key == ord('y') and self.prev_key == ord('y'):
if self.cursor_y < len(self.lines):
self.clipboard = self.lines[self.cursor_y]
elif key == ord('p'):
self._insert_line(self.cursor_y + 1, self.clipboard)
self.cursor_y += 1
self.cursor_x = 0
elif key == ord('P'):
self._insert_line(self.cursor_y, self.clipboard)
self.cursor_x = 0
elif key == ord('w'):
self._move_word_forward()
elif key == ord('b'):
self._move_word_backward()
elif key == ord('0'):
self.cursor_x = 0
elif key == ord('$'):
self.cursor_x = len(self.lines[self.cursor_y])
elif key == ord('g'):
if self.prev_key == ord('g'):
self.cursor_y = 0
self.cursor_x = 0
elif key == ord('G'):
self.cursor_y = max(0, len(self.lines) - 1)
self.cursor_x = 0
elif key == ord('u'):
self.undo()
elif key == ord('r') and self.prev_key == 18: # Ctrl-R
self.redo()
self.prev_key = key
except Exception:
pass
def _move_word_forward(self):
"""Move cursor forward by word."""
if self.cursor_y >= len(self.lines):
return
line = self.lines[self.cursor_y]
i = self.cursor_x
# Skip non-alphanumeric
while i < len(line) and not line[i].isalnum():
i += 1
# Skip alphanumeric
while i < len(line) and line[i].isalnum():
i += 1
self.cursor_x = i
def _move_word_backward(self):
"""Move cursor backward by word."""
if self.cursor_y >= len(self.lines):
return
line = self.lines[self.cursor_y]
i = max(0, self.cursor_x - 1)
# Skip non-alphanumeric
while i >= 0 and not line[i].isalnum():
i -= 1
# Skip alphanumeric
while i >= 0 and line[i].isalnum():
i -= 1
self.cursor_x = max(0, i + 1)
def handle_insert(self, key):
"""Handle insert mode keys."""
try:
if key == 27: # ESC
self.mode = 'normal'
if self.cursor_x > 0:
self.cursor_x -= 1
elif key == 10 or key == 13: # Enter
self._split_line()
elif key == curses.KEY_BACKSPACE or key == 127 or key == 8:
self._backspace()
elif 32 <= key <= 126:
char = chr(key)
self._insert_char(char)
except Exception:
pass
def handle_command(self, key):
"""Handle command mode keys."""
try:
if key == 10 or key == 13: # Enter
cmd = self.command[1:].strip()
if cmd in ["q", "q!"]:
self.running = False
elif cmd == "w":
self._save_file()
elif cmd in ["wq", "wq!", "x", "xq", "x!"]:
self._save_file()
self.running = False
elif cmd.startswith("w "):
self.filename = cmd[2:].strip()
self._save_file()
self.mode = 'normal'
self.command = ""
elif key == 27: # ESC
self.mode = 'normal'
self.command = ""
elif key == curses.KEY_BACKSPACE or key == 127 or key == 8:
if len(self.command) > 1:
self.command = self.command[:-1]
elif 32 <= key <= 126:
self.command += chr(key)
except Exception:
self.mode = 'normal'
self.command = ""
def move_cursor(self, dy, dx):
"""Move cursor with bounds checking."""
if not self.lines:
self.lines = [""]
new_y = self.cursor_y + dy
new_x = self.cursor_x + dx
# Ensure valid Y position
if 0 <= new_y < len(self.lines):
self.cursor_y = new_y
# Ensure valid X position for new line
max_x = len(self.lines[self.cursor_y])
self.cursor_x = max(0, min(new_x, max_x))
elif new_y < 0:
self.cursor_y = 0
self.cursor_x = 0
elif new_y >= len(self.lines):
self.cursor_y = max(0, len(self.lines) - 1)
self.cursor_x = len(self.lines[self.cursor_y])
def save_state(self):
"""Save current state for undo."""
with self.lock:
state = {
'lines': [line for line in self.lines],
'cursor_y': self.cursor_y,
'cursor_x': self.cursor_x
}
self.undo_stack.append(state)
if len(self.undo_stack) > self.max_undo:
self.undo_stack.pop(0)
self.redo_stack.clear()
def undo(self):
"""Undo last change."""
with self.lock:
if self.undo_stack:
current_state = {
'lines': [line for line in self.lines],
'cursor_y': self.cursor_y,
'cursor_x': self.cursor_x
}
self.redo_stack.append(current_state)
state = self.undo_stack.pop()
self.lines = state['lines']
self.cursor_y = min(state['cursor_y'], len(self.lines) - 1)
self.cursor_x = min(state['cursor_x'], len(self.lines[self.cursor_y]) if self.lines else 0)
def redo(self):
"""Redo last undone change."""
with self.lock:
if self.redo_stack:
current_state = {
'lines': [line for line in self.lines],
'cursor_y': self.cursor_y,
'cursor_x': self.cursor_x
}
self.undo_stack.append(current_state)
state = self.redo_stack.pop()
self.lines = state['lines']
self.cursor_y = min(state['cursor_y'], len(self.lines) - 1)
self.cursor_x = min(state['cursor_x'], len(self.lines[self.cursor_y]) if self.lines else 0)
def _insert_text(self, text):
"""Insert text at cursor position."""
if not text:
return
self.save_state()
lines = text.split('\n')
if len(lines) == 1:
# Single line insert
if self.cursor_y >= len(self.lines):
self.lines.append("")
self.cursor_y = len(self.lines) - 1
line = self.lines[self.cursor_y]
self.lines[self.cursor_y] = line[:self.cursor_x] + text + line[self.cursor_x:]
self.cursor_x += len(text)
else:
# Multi-line insert
if self.cursor_y >= len(self.lines):
self.lines.append("")
self.cursor_y = len(self.lines) - 1
first = self.lines[self.cursor_y][:self.cursor_x] + lines[0]
last = lines[-1] + self.lines[self.cursor_y][self.cursor_x:]
self.lines[self.cursor_y] = first
for i in range(1, len(lines) - 1):
self.lines.insert(self.cursor_y + i, lines[i])
self.lines.insert(self.cursor_y + len(lines) - 1, last)
self.cursor_y += len(lines) - 1
self.cursor_x = len(lines[-1])
def insert_text(self, text):
"""Thread-safe text insertion."""
try:
self.client_sock.send(pickle.dumps({'command': 'insert_text', 'text': text}))
except:
with self.lock:
self._insert_text(text)
def _delete_char(self):
"""Delete character at cursor."""
self.save_state()
if self.cursor_y < len(self.lines) and self.cursor_x < len(self.lines[self.cursor_y]):
line = self.lines[self.cursor_y]
self.lines[self.cursor_y] = line[:self.cursor_x] + line[self.cursor_x+1:]
def delete_char(self):
"""Thread-safe character deletion."""
try:
self.client_sock.send(pickle.dumps({'command': 'delete_char'}))
except:
with self.lock:
self._delete_char()
def _insert_char(self, char):
"""Insert single character."""
if self.cursor_y >= len(self.lines):
self.lines.append("")
self.cursor_y = len(self.lines) - 1
line = self.lines[self.cursor_y]
self.lines[self.cursor_y] = line[:self.cursor_x] + char + line[self.cursor_x:]
self.cursor_x += 1
def _split_line(self):
"""Split line at cursor."""
if self.cursor_y >= len(self.lines):
self.lines.append("")
self.cursor_y = len(self.lines) - 1
line = self.lines[self.cursor_y]
self.lines[self.cursor_y] = line[:self.cursor_x]
self.lines.insert(self.cursor_y + 1, line[self.cursor_x:])
self.cursor_y += 1
self.cursor_x = 0
def _backspace(self):
"""Handle backspace key."""
if self.cursor_x > 0:
line = self.lines[self.cursor_y]
self.lines[self.cursor_y] = line[:self.cursor_x-1] + line[self.cursor_x:]
self.cursor_x -= 1
elif self.cursor_y > 0:
prev_len = len(self.lines[self.cursor_y - 1])
self.lines[self.cursor_y - 1] += self.lines[self.cursor_y]
del self.lines[self.cursor_y]
self.cursor_y -= 1
self.cursor_x = prev_len
def _insert_line(self, line_num, text):
"""Insert a new line."""
self.save_state()
line_num = max(0, min(line_num, len(self.lines)))
self.lines.insert(line_num, text)
def _delete_line(self, line_num):
"""Delete a line."""
self.save_state()
if 0 <= line_num < len(self.lines):
if len(self.lines) > 1:
del self.lines[line_num]
else:
self.lines = [""]
def _set_text(self, text):
"""Set entire text content."""
self.save_state()
self.lines = text.splitlines() if text else [""]
self.cursor_y = 0
self.cursor_x = 0
def set_text(self, text):
"""Thread-safe text setting."""
try:
self.client_sock.send(pickle.dumps({'command': 'set_text', 'text': text}))
except:
with self.lock:
self._set_text(text)
def _goto_line(self, line_num):
"""Go to specific line."""
line_num = max(0, min(line_num - 1, len(self.lines) - 1))
self.cursor_y = line_num
self.cursor_x = 0
def goto_line(self, line_num):
"""Thread-safe goto line."""
try:
self.client_sock.send(pickle.dumps({'command': 'goto_line', 'line_num': line_num}))
except:
with self.lock:
self._goto_line(line_num)
def get_text(self):
"""Get entire text content."""
try:
self.client_sock.send(pickle.dumps({'command': 'get_text'}))
data = self.client_sock.recv(65536)
return pickle.loads(data)
except:
with self.lock:
return '\n'.join(self.lines)
def get_cursor(self):
"""Get cursor position."""
try:
self.client_sock.send(pickle.dumps({'command': 'get_cursor'}))
data = self.client_sock.recv(4096)
return pickle.loads(data)
except:
with self.lock:
return (self.cursor_y, self.cursor_x)
def get_file_info(self):
"""Get file information."""
try:
self.client_sock.send(pickle.dumps({'command': 'get_file_info'}))
data = self.client_sock.recv(4096)
return pickle.loads(data)
except:
with self.lock:
return {
'filename': self.filename,
'lines': len(self.lines),
'cursor': (self.cursor_y, self.cursor_x),
'mode': self.mode
}
def socket_listener(self):
"""Listen for socket commands with error handling."""
while self.running:
try:
data = self.server_sock.recv(65536)
if not data:
break
command = pickle.loads(data)
self.command_queue.put(command)
except socket.timeout:
continue
except OSError:
if self.running:
continue
else:
break
except Exception:
continue
def execute_command(self, command):
"""Execute command with error handling."""
try:
cmd = command.get('command')
if cmd == 'insert_text':
self._insert_text(command.get('text', ''))
elif cmd == 'delete_char':
self._delete_char()
elif cmd == 'save_file':
self._save_file()
elif cmd == 'set_text':
self._set_text(command.get('text', ''))
elif cmd == 'goto_line':
self._goto_line(command.get('line_num', 1))
elif cmd == 'get_text':
result = '\n'.join(self.lines)
self.server_sock.send(pickle.dumps(result))
elif cmd == 'get_cursor':
result = (self.cursor_y, self.cursor_x)
self.server_sock.send(pickle.dumps(result))
elif cmd == 'get_file_info':
result = {
'filename': self.filename,
'lines': len(self.lines),
'cursor': (self.cursor_y, self.cursor_x),
'mode': self.mode
}
self.server_sock.send(pickle.dumps(result))
elif cmd == 'stop':
self.running = False
except Exception:
pass
# Additional public methods for backwards compatibility
def move_cursor_to(self, y, x):
"""Move cursor to specific position."""
with self.lock:
self.cursor_y = max(0, min(y, len(self.lines) - 1))
self.cursor_x = max(0, min(x, len(self.lines[self.cursor_y])))
def get_line(self, line_num):
"""Get specific line."""
with self.lock:
if 0 <= line_num < len(self.lines):
return self.lines[line_num]
return None
def get_lines(self, start, end):
"""Get range of lines."""
with self.lock:
start = max(0, start)
end = min(end, len(self.lines))
return self.lines[start:end]
def insert_at_line(self, line_num, text):
"""Insert text at specific line."""
with self.lock:
self.save_state()
line_num = max(0, min(line_num, len(self.lines)))
self.lines.insert(line_num, text)
def delete_lines(self, start, end):
"""Delete range of lines."""
with self.lock:
self.save_state()
start = max(0, start)
end = min(end, len(self.lines))
if start < end:
del self.lines[start:end]
if not self.lines:
self.lines = [""]
def replace_text(self, start_line, start_col, end_line, end_col, new_text):
"""Replace text in range."""
with self.lock:
self.save_state()
# Validate bounds
start_line = max(0, min(start_line, len(self.lines) - 1))
end_line = max(0, min(end_line, len(self.lines) - 1))
if start_line == end_line:
line = self.lines[start_line]
start_col = max(0, min(start_col, len(line)))
end_col = max(0, min(end_col, len(line)))
self.lines[start_line] = line[:start_col] + new_text + line[end_col:]
else:
first_part = self.lines[start_line][:start_col]
last_part = self.lines[end_line][end_col:]
new_lines = new_text.split('\n')
self.lines[start_line] = first_part + new_lines[0]
del self.lines[start_line + 1:end_line + 1]
for i, new_line in enumerate(new_lines[1:], 1):
self.lines.insert(start_line + i, new_line)
if len(new_lines) > 1:
self.lines[start_line + len(new_lines) - 1] += last_part
else:
self.lines[start_line] += last_part
def search(self, pattern, start_line=0):
"""Search for pattern in text."""
with self.lock:
results = []
try:
for i in range(start_line, len(self.lines)):
matches = re.finditer(pattern, self.lines[i])
for match in matches:
results.append((i, match.start(), match.end()))
except re.error:
pass
return results
def replace_all(self, search_text, replace_text):
"""Replace all occurrences of text."""
with self.lock:
self.save_state()
for i in range(len(self.lines)):
self.lines[i] = self.lines[i].replace(search_text, replace_text)
def select_range(self, start_line, start_col, end_line, end_col):
"""Select text range."""
with self.lock:
self.selection_start = (start_line, start_col)
self.selection_end = (end_line, end_col)
def get_selection(self):
"""Get selected text."""
with self.lock:
if not self.selection_start or not self.selection_end:
return ""
sl, sc = self.selection_start
el, ec = self.selection_end
# Validate bounds
if sl < 0 or sl >= len(self.lines) or el < 0 or el >= len(self.lines):
return ""
if sl == el:
return self.lines[sl][sc:ec]
result = [self.lines[sl][sc:]]
for i in range(sl + 1, el):
if i < len(self.lines):
result.append(self.lines[i])
if el < len(self.lines):
result.append(self.lines[el][:ec])
return '\n'.join(result)
def delete_selection(self):
"""Delete selected text."""
with self.lock:
if not self.selection_start or not self.selection_end:
return
self.save_state()
sl, sc = self.selection_start
el, ec = self.selection_end
if 0 <= sl < len(self.lines) and 0 <= el < len(self.lines):
self.replace_text(sl, sc, el, ec, "")
self.selection_start = None
self.selection_end = None
def apply_search_replace_block(self, search_block, replace_block):
"""Apply search and replace on block."""
with self.lock:
self.save_state()
search_lines = search_block.splitlines()
replace_lines = replace_block.splitlines()
for i in range(len(self.lines) - len(search_lines) + 1):
match = True
for j, search_line in enumerate(search_lines):
if i + j >= len(self.lines):
match = False
break
if self.lines[i + j].strip() != search_line.strip():
match = False
break
if match:
# Preserve indentation
indent = len(self.lines[i]) - len(self.lines[i].lstrip())
indented_replace = [' ' * indent + line for line in replace_lines]
self.lines[i:i+len(search_lines)] = indented_replace
return True
return False
def apply_diff(self, diff_text):
"""Apply unified diff."""
with self.lock:
self.save_state()
try:
lines = diff_text.split('\n')
start_line = 0
for line in lines:
if line.startswith('@@'):
match = re.search(r'@@ -(\d+),?(\d*) \+(\d+),?(\d*) @@', line)
if match:
start_line = int(match.group(1)) - 1
elif line.startswith('-'):
if start_line < len(self.lines):
del self.lines[start_line]
elif line.startswith('+'):
self.lines.insert(start_line, line[1:])
start_line += 1
elif line and not line.startswith('\\'):
start_line += 1
except Exception:
pass
def get_context(self, line_num, context_lines=3):
"""Get lines around specific line."""
with self.lock:
start = max(0, line_num - context_lines)
end = min(len(self.lines), line_num + context_lines + 1)
return self.get_lines(start, end)
def count_lines(self):
"""Count total lines."""
with self.lock:
return len(self.lines)
def close(self):
"""Close the editor."""
self.stop()
def is_running(self):
"""Check if editor is running."""
return self.running
def wait(self, timeout=None):
"""Wait for editor to finish."""
if self.thread and self.thread.is_alive():
self.thread.join(timeout=timeout)
return not self.thread.is_alive()
return True
def __enter__(self):
"""Context manager entry."""
self.start()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Context manager exit."""
if exc_type:
self._exception_occurred = True
self.stop()
return False
def __del__(self):
"""Destructor for cleanup."""
self._cleanup()
def main():
"""Main entry point with error handling."""
editor = None
try:
filename = sys.argv[1] if len(sys.argv) > 1 else None
# Parse additional arguments
auto_save = '--auto-save' in sys.argv
# Create and start editor
editor = RPEditor(filename, auto_save=auto_save)
editor.start()
# Wait for editor to finish
if editor.thread:
editor.thread.join()
except KeyboardInterrupt:
pass
except Exception as e:
print(f"Error: {e}", file=sys.stderr)
finally:
if editor:
editor.stop()
# Ensure screen is cleared
os.system('clear' if os.name != 'nt' else 'cls')
if __name__ == "__main__":
if "rpe" in sys.argv[0]:
main()

587
pr/editor2.py Normal file
View File

@ -0,0 +1,587 @@
#!/usr/bin/env python3
import curses
import threading
import sys
import os
import re
import socket
import pickle
import queue
class RPEditor:
def __init__(self, filename=None):
self.filename = filename
self.lines = [""]
self.cursor_y = 0
self.cursor_x = 0
self.mode = 'normal'
self.command = ""
self.stdscr = None
self.running = False
self.thread = None
self.socket_thread = None
self.prev_key = None
self.clipboard = ""
self.undo_stack = []
self.redo_stack = []
self.selection_start = None
self.selection_end = None
self.max_undo = 100
self.lock = threading.RLock()
self.client_sock, self.server_sock = socket.socketpair()
self.command_queue = queue.Queue()
if filename:
self.load_file()
def load_file(self):
try:
with open(self.filename, 'r') as f:
self.lines = f.read().splitlines()
if not self.lines:
self.lines = [""]
except:
self.lines = [""]
def _save_file(self):
with self.lock:
if self.filename:
with open(self.filename, 'w') as f:
f.write('\n'.join(self.lines))
def save_file(self):
self.client_sock.send(pickle.dumps({'command': 'save_file'}))
def start(self):
self.running = True
self.socket_thread = threading.Thread(target=self.socket_listener)
self.socket_thread.start()
self.thread = threading.Thread(target=self.run)
self.thread.start()
def stop(self):
self.client_sock.send(pickle.dumps({'command': 'stop'}))
self.running = False
if self.stdscr:
curses.endwin()
if self.thread:
self.thread.join()
if self.socket_thread:
self.socket_thread.join()
self.client_sock.close()
self.server_sock.close()
def run(self):
curses.wrapper(self.main_loop)
def main_loop(self, stdscr):
self.stdscr = stdscr
curses.curs_set(1)
self.stdscr.keypad(True)
while self.running:
with self.lock:
self.draw()
try:
while True:
command = self.command_queue.get_nowait()
with self.lock:
self.execute_command(command)
except queue.Empty:
pass
key = self.stdscr.getch()
with self.lock:
self.handle_key(key)
def draw(self):
self.stdscr.clear()
height, width = self.stdscr.getmaxyx()
for i, line in enumerate(self.lines):
if i < height - 1:
self.stdscr.addstr(i, 0, line[:width])
status = f"{self.mode.upper()} | {self.filename or 'untitled'} | {self.cursor_y+1}:{self.cursor_x+1}"
self.stdscr.addstr(height - 1, 0, status[:width])
if self.mode == 'command':
self.stdscr.addstr(height - 1, 0, self.command[:width])
self.stdscr.move(self.cursor_y, min(self.cursor_x, width - 1))
self.stdscr.refresh()
def handle_key(self, key):
if self.mode == 'normal':
self.handle_normal(key)
elif self.mode == 'insert':
self.handle_insert(key)
elif self.mode == 'command':
self.handle_command(key)
def handle_normal(self, key):
if key == ord('h') or key == curses.KEY_LEFT:
self.move_cursor(0, -1)
elif key == ord('j') or key == curses.KEY_DOWN:
self.move_cursor(1, 0)
elif key == ord('k') or key == curses.KEY_UP:
self.move_cursor(-1, 0)
elif key == ord('l') or key == curses.KEY_RIGHT:
self.move_cursor(0, 1)
elif key == ord('i'):
self.mode = 'insert'
elif key == ord(':'):
self.mode = 'command'
self.command = ":"
elif key == ord('x'):
self._delete_char()
elif key == ord('a'):
self.cursor_x += 1
self.mode = 'insert'
elif key == ord('A'):
self.cursor_x = len(self.lines[self.cursor_y])
self.mode = 'insert'
elif key == ord('o'):
self._insert_line(self.cursor_y + 1, "")
self.cursor_y += 1
self.cursor_x = 0
self.mode = 'insert'
elif key == ord('O'):
self._insert_line(self.cursor_y, "")
self.cursor_x = 0
self.mode = 'insert'
elif key == ord('d') and self.prev_key == ord('d'):
self.clipboard = self.lines[self.cursor_y]
self._delete_line(self.cursor_y)
if self.cursor_y >= len(self.lines):
self.cursor_y = len(self.lines) - 1
self.cursor_x = 0
elif key == ord('y') and self.prev_key == ord('y'):
self.clipboard = self.lines[self.cursor_y]
elif key == ord('p'):
self._insert_line(self.cursor_y + 1, self.clipboard)
self.cursor_y += 1
self.cursor_x = 0
elif key == ord('P'):
self._insert_line(self.cursor_y, self.clipboard)
self.cursor_x = 0
elif key == ord('w'):
line = self.lines[self.cursor_y]
i = self.cursor_x
while i < len(line) and not line[i].isalnum():
i += 1
while i < len(line) and line[i].isalnum():
i += 1
self.cursor_x = i
elif key == ord('b'):
line = self.lines[self.cursor_y]
i = self.cursor_x - 1
while i >= 0 and not line[i].isalnum():
i -= 1
while i >= 0 and line[i].isalnum():
i -= 1
self.cursor_x = i + 1
elif key == ord('0'):
self.cursor_x = 0
elif key == ord('$'):
self.cursor_x = len(self.lines[self.cursor_y])
elif key == ord('g'):
if self.prev_key == ord('g'):
self.cursor_y = 0
self.cursor_x = 0
elif key == ord('G'):
self.cursor_y = len(self.lines) - 1
self.cursor_x = 0
elif key == ord('u'):
self.undo()
elif key == ord('r') and self.prev_key == 18:
self.redo()
self.prev_key = key
def handle_insert(self, key):
if key == 27:
self.mode = 'normal'
if self.cursor_x > 0:
self.cursor_x -= 1
elif key == 10:
self._split_line()
elif key == curses.KEY_BACKSPACE or key == 127:
self._backspace()
elif 32 <= key <= 126:
char = chr(key)
self._insert_char(char)
def handle_command(self, key):
if key == 10:
cmd = self.command[1:]
if cmd == "q" or cmd == 'q!':
self.running = False
elif cmd == "w":
self._save_file()
elif cmd == "wq" or cmd == "wq!" or cmd == "x" or cmd == "xq" or cmd == "x!":
self._save_file()
self.running = False
elif cmd.startswith("w "):
self.filename = cmd[2:]
self._save_file()
elif cmd == "wq":
self._save_file()
self.running = False
self.mode = 'normal'
self.command = ""
elif key == 27:
self.mode = 'normal'
self.command = ""
elif key == curses.KEY_BACKSPACE or key == 127:
if len(self.command) > 1:
self.command = self.command[:-1]
elif 32 <= key <= 126:
self.command += chr(key)
def move_cursor(self, dy, dx):
new_y = self.cursor_y + dy
new_x = self.cursor_x + dx
if 0 <= new_y < len(self.lines):
self.cursor_y = new_y
self.cursor_x = max(0, min(new_x, len(self.lines[self.cursor_y])))
def save_state(self):
with self.lock:
state = {
'lines': [line for line in self.lines],
'cursor_y': self.cursor_y,
'cursor_x': self.cursor_x
}
self.undo_stack.append(state)
if len(self.undo_stack) > self.max_undo:
self.undo_stack.pop(0)
self.redo_stack.clear()
def undo(self):
with self.lock:
if self.undo_stack:
current_state = {
'lines': [line for line in self.lines],
'cursor_y': self.cursor_y,
'cursor_x': self.cursor_x
}
self.redo_stack.append(current_state)
state = self.undo_stack.pop()
self.lines = state['lines']
self.cursor_y = state['cursor_y']
self.cursor_x = state['cursor_x']
def redo(self):
with self.lock:
if self.redo_stack:
current_state = {
'lines': [line for line in self.lines],
'cursor_y': self.cursor_y,
'cursor_x': self.cursor_x
}
self.undo_stack.append(current_state)
state = self.redo_stack.pop()
self.lines = state['lines']
self.cursor_y = state['cursor_y']
self.cursor_x = state['cursor_x']
def _insert_text(self, text):
self.save_state()
lines = text.split('\n')
if len(lines) == 1:
self.lines[self.cursor_y] = self.lines[self.cursor_y][:self.cursor_x] + text + self.lines[self.cursor_y][self.cursor_x:]
self.cursor_x += len(text)
else:
first = self.lines[self.cursor_y][:self.cursor_x] + lines[0]
last = lines[-1] + self.lines[self.cursor_y][self.cursor_x:]
self.lines[self.cursor_y] = first
for i in range(1, len(lines)-1):
self.lines.insert(self.cursor_y + i, lines[i])
self.lines.insert(self.cursor_y + len(lines) - 1, last)
self.cursor_y += len(lines) - 1
self.cursor_x = len(lines[-1])
def insert_text(self, text):
self.client_sock.send(pickle.dumps({'command': 'insert_text', 'text': text}))
def _delete_char(self):
self.save_state()
if self.cursor_x < len(self.lines[self.cursor_y]):
self.lines[self.cursor_y] = self.lines[self.cursor_y][:self.cursor_x] + self.lines[self.cursor_y][self.cursor_x+1:]
def delete_char(self):
self.client_sock.send(pickle.dumps({'command': 'delete_char'}))
def _insert_char(self, char):
self.lines[self.cursor_y] = self.lines[self.cursor_y][:self.cursor_x] + char + self.lines[self.cursor_y][self.cursor_x:]
self.cursor_x += 1
def _split_line(self):
line = self.lines[self.cursor_y]
self.lines[self.cursor_y] = line[:self.cursor_x]
self.lines.insert(self.cursor_y + 1, line[self.cursor_x:])
self.cursor_y += 1
self.cursor_x = 0
def _backspace(self):
if self.cursor_x > 0:
self.lines[self.cursor_y] = self.lines[self.cursor_y][:self.cursor_x-1] + self.lines[self.cursor_y][self.cursor_x:]
self.cursor_x -= 1
elif self.cursor_y > 0:
prev_len = len(self.lines[self.cursor_y - 1])
self.lines[self.cursor_y - 1] += self.lines[self.cursor_y]
del self.lines[self.cursor_y]
self.cursor_y -= 1
self.cursor_x = prev_len
def _insert_line(self, line_num, text):
self.save_state()
line_num = max(0, min(line_num, len(self.lines)))
self.lines.insert(line_num, text)
def _delete_line(self, line_num):
self.save_state()
if 0 <= line_num < len(self.lines):
if len(self.lines) > 1:
del self.lines[line_num]
else:
self.lines = [""]
def _set_text(self, text):
self.save_state()
self.lines = text.splitlines() if text else [""]
self.cursor_y = 0
self.cursor_x = 0
def set_text(self, text):
self.client_sock.send(pickle.dumps({'command': 'set_text', 'text': text}))
def _goto_line(self, line_num):
line_num = max(0, min(line_num, len(self.lines) - 1))
self.cursor_y = line_num
self.cursor_x = 0
def goto_line(self, line_num):
self.client_sock.send(pickle.dumps({'command': 'goto_line', 'line_num': line_num}))
def get_text(self):
self.client_sock.send(pickle.dumps({'command': 'get_text'}))
try:
return pickle.loads(self.client_sock.recv(4096))
except:
return ''
def get_cursor(self):
self.client_sock.send(pickle.dumps({'command': 'get_cursor'}))
try:
return pickle.loads(self.client_sock.recv(4096))
except:
return (0, 0)
def get_file_info(self):
self.client_sock.send(pickle.dumps({'command': 'get_file_info'}))
try:
return pickle.loads(self.client_sock.recv(4096))
except:
return {}
def socket_listener(self):
while self.running:
try:
data = self.server_sock.recv(4096)
if not data:
break
command = pickle.loads(data)
self.command_queue.put(command)
except OSError:
break
def execute_command(self, command):
cmd = command.get('command')
if cmd == 'insert_text':
self._insert_text(command['text'])
elif cmd == 'delete_char':
self._delete_char()
elif cmd == 'save_file':
self._save_file()
elif cmd == 'set_text':
self._set_text(command['text'])
elif cmd == 'goto_line':
self._goto_line(command['line_num'])
elif cmd == 'get_text':
result = '\n'.join(self.lines)
try:
self.server_sock.send(pickle.dumps(result))
except:
pass
elif cmd == 'get_cursor':
result = (self.cursor_y, self.cursor_x)
try:
self.server_sock.send(pickle.dumps(result))
except:
pass
elif cmd == 'get_file_info':
result = {
'filename': self.filename,
'lines': len(self.lines),
'cursor': (self.cursor_y, self.cursor_x),
'mode': self.mode
}
try:
self.server_sock.send(pickle.dumps(result))
except:
pass
elif cmd == 'stop':
self.running = False
def move_cursor_to(self, y, x):
with self.lock:
self.cursor_y = max(0, min(y, len(self.lines)-1))
self.cursor_x = max(0, min(x, len(self.lines[self.cursor_y])))
def get_line(self, line_num):
with self.lock:
if 0 <= line_num < len(self.lines):
return self.lines[line_num]
return None
def get_lines(self, start, end):
with self.lock:
start = max(0, start)
end = min(end, len(self.lines))
return self.lines[start:end]
def insert_at_line(self, line_num, text):
with self.lock:
self.save_state()
line_num = max(0, min(line_num, len(self.lines)))
self.lines.insert(line_num, text)
def delete_lines(self, start, end):
with self.lock:
self.save_state()
start = max(0, start)
end = min(end, len(self.lines))
if start < end:
del self.lines[start:end]
if not self.lines:
self.lines = [""]
def replace_text(self, start_line, start_col, end_line, end_col, new_text):
with self.lock:
self.save_state()
if start_line == end_line:
line = self.lines[start_line]
self.lines[start_line] = line[:start_col] + new_text + line[end_col:]
else:
first_part = self.lines[start_line][:start_col]
last_part = self.lines[end_line][end_col:]
new_lines = new_text.split('\n')
self.lines[start_line] = first_part + new_lines[0]
del self.lines[start_line + 1:end_line + 1]
for i, new_line in enumerate(new_lines[1:], 1):
self.lines.insert(start_line + i, new_line)
if len(new_lines) > 1:
self.lines[start_line + len(new_lines) - 1] += last_part
else:
self.lines[start_line] += last_part
def search(self, pattern, start_line=0):
with self.lock:
results = []
for i in range(start_line, len(self.lines)):
matches = re.finditer(pattern, self.lines[i])
for match in matches:
results.append((i, match.start(), match.end()))
return results
def replace_all(self, search_text, replace_text):
with self.lock:
self.save_state()
for i in range(len(self.lines)):
self.lines[i] = self.lines[i].replace(search_text, replace_text)
def select_range(self, start_line, start_col, end_line, end_col):
with self.lock:
self.selection_start = (start_line, start_col)
self.selection_end = (end_line, end_col)
def get_selection(self):
with self.lock:
if not self.selection_start or not self.selection_end:
return ""
sl, sc = self.selection_start
el, ec = self.selection_end
if sl == el:
return self.lines[sl][sc:ec]
result = [self.lines[sl][sc:]]
for i in range(sl + 1, el):
result.append(self.lines[i])
result.append(self.lines[el][:ec])
return '\n'.join(result)
def delete_selection(self):
with self.lock:
if not self.selection_start or not self.selection_end:
return
self.save_state()
sl, sc = self.selection_start
el, ec = self.selection_end
self.replace_text(sl, sc, el, ec, "")
self.selection_start = None
self.selection_end = None
def apply_search_replace_block(self, search_block, replace_block):
with self.lock:
self.save_state()
search_lines = search_block.splitlines()
replace_lines = replace_block.splitlines()
for i in range(len(self.lines) - len(search_lines) + 1):
match = True
for j, search_line in enumerate(search_lines):
if self.lines[i + j].strip() != search_line.strip():
match = False
break
if match:
indent = len(self.lines[i]) - len(self.lines[i].lstrip())
indented_replace = [' ' * indent + line for line in replace_lines]
self.lines[i:i+len(search_lines)] = indented_replace
return True
return False
def apply_diff(self, diff_text):
with self.lock:
self.save_state()
lines = diff_text.split('\n')
for line in lines:
if line.startswith('@@'):
match = re.search(r'@@ -(\d+),?(\d*) \+(\d+),?(\d*) @@', line)
if match:
start_line = int(match.group(1)) - 1
elif line.startswith('-'):
if start_line < len(self.lines):
del self.lines[start_line]
elif line.startswith('+'):
self.lines.insert(start_line, line[1:])
start_line += 1
def get_context(self, line_num, context_lines=3):
with self.lock:
start = max(0, line_num - context_lines)
end = min(len(self.lines), line_num + context_lines + 1)
return self.get_lines(start, end)
def count_lines(self):
with self.lock:
return len(self.lines)
def close(self):
self.running = False
self.stop()
if self.thread:
self.thread.join()
def main():
filename = sys.argv[1] if len(sys.argv) > 1 else None
editor = RPEditor(filename)
editor.start()
editor.thread.join()
if __name__ == "__main__":
main()

7
pr/memory/__init__.py Normal file
View File

@ -0,0 +1,7 @@
from .knowledge_store import KnowledgeStore, KnowledgeEntry
from .semantic_index import SemanticIndex
from .conversation_memory import ConversationMemory
from .fact_extractor import FactExtractor
__all__ = ['KnowledgeStore', 'KnowledgeEntry', 'SemanticIndex',
'ConversationMemory', 'FactExtractor']

View File

@ -0,0 +1,259 @@
import json
import sqlite3
import time
from typing import List, Dict, Any, Optional
class ConversationMemory:
def __init__(self, db_path: str):
self.db_path = db_path
self._initialize_memory()
def _initialize_memory(self):
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute('''
CREATE TABLE IF NOT EXISTS conversation_history (
conversation_id TEXT PRIMARY KEY,
session_id TEXT,
started_at REAL NOT NULL,
ended_at REAL,
message_count INTEGER DEFAULT 0,
summary TEXT,
topics TEXT,
metadata TEXT
)
''')
cursor.execute('''
CREATE TABLE IF NOT EXISTS conversation_messages (
message_id TEXT PRIMARY KEY,
conversation_id TEXT NOT NULL,
role TEXT NOT NULL,
content TEXT NOT NULL,
timestamp REAL NOT NULL,
tool_calls TEXT,
metadata TEXT,
FOREIGN KEY (conversation_id) REFERENCES conversation_history(conversation_id)
)
''')
cursor.execute('''
CREATE INDEX IF NOT EXISTS idx_conv_session ON conversation_history(session_id)
''')
cursor.execute('''
CREATE INDEX IF NOT EXISTS idx_conv_started ON conversation_history(started_at DESC)
''')
cursor.execute('''
CREATE INDEX IF NOT EXISTS idx_msg_conversation ON conversation_messages(conversation_id)
''')
cursor.execute('''
CREATE INDEX IF NOT EXISTS idx_msg_timestamp ON conversation_messages(timestamp)
''')
conn.commit()
conn.close()
def create_conversation(self, conversation_id: str, session_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None):
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute('''
INSERT INTO conversation_history
(conversation_id, session_id, started_at, metadata)
VALUES (?, ?, ?, ?)
''', (
conversation_id,
session_id,
time.time(),
json.dumps(metadata) if metadata else None
))
conn.commit()
conn.close()
def add_message(self, conversation_id: str, message_id: str, role: str,
content: str, tool_calls: Optional[List[Dict[str, Any]]] = None,
metadata: Optional[Dict[str, Any]] = None):
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute('''
INSERT INTO conversation_messages
(message_id, conversation_id, role, content, timestamp, tool_calls, metadata)
VALUES (?, ?, ?, ?, ?, ?, ?)
''', (
message_id,
conversation_id,
role,
content,
time.time(),
json.dumps(tool_calls) if tool_calls else None,
json.dumps(metadata) if metadata else None
))
cursor.execute('''
UPDATE conversation_history
SET message_count = message_count + 1
WHERE conversation_id = ?
''', (conversation_id,))
conn.commit()
conn.close()
def get_conversation_messages(self, conversation_id: str,
limit: Optional[int] = None) -> List[Dict[str, Any]]:
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
if limit:
cursor.execute('''
SELECT message_id, role, content, timestamp, tool_calls, metadata
FROM conversation_messages
WHERE conversation_id = ?
ORDER BY timestamp DESC
LIMIT ?
''', (conversation_id, limit))
else:
cursor.execute('''
SELECT message_id, role, content, timestamp, tool_calls, metadata
FROM conversation_messages
WHERE conversation_id = ?
ORDER BY timestamp ASC
''', (conversation_id,))
messages = []
for row in cursor.fetchall():
messages.append({
'message_id': row[0],
'role': row[1],
'content': row[2],
'timestamp': row[3],
'tool_calls': json.loads(row[4]) if row[4] else None,
'metadata': json.loads(row[5]) if row[5] else None
})
conn.close()
return messages
def update_conversation_summary(self, conversation_id: str, summary: str,
topics: Optional[List[str]] = None):
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute('''
UPDATE conversation_history
SET summary = ?, topics = ?, ended_at = ?
WHERE conversation_id = ?
''', (summary, json.dumps(topics) if topics else None, time.time(), conversation_id))
conn.commit()
conn.close()
def search_conversations(self, query: str, limit: int = 10) -> List[Dict[str, Any]]:
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute('''
SELECT DISTINCT h.conversation_id, h.session_id, h.started_at,
h.message_count, h.summary, h.topics
FROM conversation_history h
LEFT JOIN conversation_messages m ON h.conversation_id = m.conversation_id
WHERE h.summary LIKE ? OR h.topics LIKE ? OR m.content LIKE ?
ORDER BY h.started_at DESC
LIMIT ?
''', (f'%{query}%', f'%{query}%', f'%{query}%', limit))
conversations = []
for row in cursor.fetchall():
conversations.append({
'conversation_id': row[0],
'session_id': row[1],
'started_at': row[2],
'message_count': row[3],
'summary': row[4],
'topics': json.loads(row[5]) if row[5] else []
})
conn.close()
return conversations
def get_recent_conversations(self, limit: int = 10,
session_id: Optional[str] = None) -> List[Dict[str, Any]]:
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
if session_id:
cursor.execute('''
SELECT conversation_id, session_id, started_at, ended_at,
message_count, summary, topics
FROM conversation_history
WHERE session_id = ?
ORDER BY started_at DESC
LIMIT ?
''', (session_id, limit))
else:
cursor.execute('''
SELECT conversation_id, session_id, started_at, ended_at,
message_count, summary, topics
FROM conversation_history
ORDER BY started_at DESC
LIMIT ?
''', (limit,))
conversations = []
for row in cursor.fetchall():
conversations.append({
'conversation_id': row[0],
'session_id': row[1],
'started_at': row[2],
'ended_at': row[3],
'message_count': row[4],
'summary': row[5],
'topics': json.loads(row[6]) if row[6] else []
})
conn.close()
return conversations
def delete_conversation(self, conversation_id: str) -> bool:
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute('DELETE FROM conversation_messages WHERE conversation_id = ?',
(conversation_id,))
cursor.execute('DELETE FROM conversation_history WHERE conversation_id = ?',
(conversation_id,))
deleted = cursor.rowcount > 0
conn.commit()
conn.close()
return deleted
def get_statistics(self) -> Dict[str, Any]:
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute('SELECT COUNT(*) FROM conversation_history')
total_conversations = cursor.fetchone()[0]
cursor.execute('SELECT COUNT(*) FROM conversation_messages')
total_messages = cursor.fetchone()[0]
cursor.execute('SELECT SUM(message_count) FROM conversation_history')
total_message_count = cursor.fetchone()[0] or 0
cursor.execute('''
SELECT AVG(message_count) FROM conversation_history WHERE message_count > 0
''')
avg_messages = cursor.fetchone()[0] or 0
conn.close()
return {
'total_conversations': total_conversations,
'total_messages': total_messages,
'average_messages_per_conversation': round(avg_messages, 2)
}

146
pr/memory/fact_extractor.py Normal file
View File

@ -0,0 +1,146 @@
import re
import json
from typing import List, Dict, Any, Set
from collections import defaultdict
class FactExtractor:
def __init__(self):
self.fact_patterns = [
(r'([A-Z][a-z]+ [A-Z][a-z]+) is (a|an) ([^.]+)', 'definition'),
(r'([A-Z][a-z]+) (was|is) (born|created|founded) in (\d{4})', 'temporal'),
(r'([A-Z][a-z]+) (invented|created|developed) ([^.]+)', 'attribution'),
(r'([^.]+) (costs?|worth) (\$[\d,]+)', 'numeric'),
(r'([A-Z][a-z]+) (lives?|works?|located) in ([A-Z][a-z]+)', 'location'),
]
def extract_facts(self, text: str) -> List[Dict[str, Any]]:
facts = []
for pattern, fact_type in self.fact_patterns:
matches = re.finditer(pattern, text)
for match in matches:
facts.append({
'type': fact_type,
'text': match.group(0),
'components': match.groups(),
'confidence': 0.7
})
noun_phrases = self._extract_noun_phrases(text)
for phrase in noun_phrases:
if len(phrase.split()) >= 2:
facts.append({
'type': 'entity',
'text': phrase,
'components': [phrase],
'confidence': 0.5
})
return facts
def _extract_noun_phrases(self, text: str) -> List[str]:
sentences = re.split(r'[.!?]', text)
phrases = []
for sentence in sentences:
words = sentence.split()
current_phrase = []
for word in words:
if word and word[0].isupper() and len(word) > 1:
current_phrase.append(word)
else:
if len(current_phrase) >= 2:
phrases.append(' '.join(current_phrase))
current_phrase = []
if len(current_phrase) >= 2:
phrases.append(' '.join(current_phrase))
return list(set(phrases))
def extract_key_terms(self, text: str, top_k: int = 10) -> List[tuple]:
words = re.findall(r'\b[a-z]{4,}\b', text.lower())
stopwords = {
'this', 'that', 'these', 'those', 'what', 'which', 'where', 'when',
'with', 'from', 'have', 'been', 'were', 'will', 'would', 'could',
'should', 'about', 'their', 'there', 'other', 'than', 'then', 'them',
'some', 'more', 'very', 'such', 'into', 'through', 'during', 'before',
'after', 'above', 'below', 'between', 'under', 'again', 'further',
'once', 'here', 'both', 'each', 'doing', 'only', 'over', 'same',
'being', 'does', 'just', 'also', 'make', 'made', 'know', 'like'
}
filtered_words = [w for w in words if w not in stopwords]
word_freq = defaultdict(int)
for word in filtered_words:
word_freq[word] += 1
sorted_terms = sorted(word_freq.items(), key=lambda x: x[1], reverse=True)
return sorted_terms[:top_k]
def extract_relationships(self, text: str) -> List[Dict[str, Any]]:
relationships = []
relationship_patterns = [
(r'([A-Z][a-z]+) (works for|employed by|member of) ([A-Z][a-z]+)', 'employment'),
(r'([A-Z][a-z]+) (owns|has|possesses) ([^.]+)', 'ownership'),
(r'([A-Z][a-z]+) (located in|part of|belongs to) ([A-Z][a-z]+)', 'location'),
(r'([A-Z][a-z]+) (uses|utilizes|implements) ([^.]+)', 'usage'),
]
for pattern, rel_type in relationship_patterns:
matches = re.finditer(pattern, text)
for match in matches:
relationships.append({
'type': rel_type,
'subject': match.group(1),
'predicate': match.group(2),
'object': match.group(3),
'confidence': 0.6
})
return relationships
def extract_metadata(self, text: str) -> Dict[str, Any]:
word_count = len(text.split())
sentence_count = len(re.split(r'[.!?]', text))
urls = re.findall(r'https?://[^\s]+', text)
email_addresses = re.findall(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b', text)
dates = re.findall(r'\b\d{1,2}[-/]\d{1,2}[-/]\d{2,4}\b|\b\d{4}[-/]\d{1,2}[-/]\d{1,2}\b', text)
numbers = re.findall(r'\b\d+(?:,\d{3})*(?:\.\d+)?\b', text)
return {
'word_count': word_count,
'sentence_count': sentence_count,
'avg_words_per_sentence': round(word_count / max(sentence_count, 1), 2),
'urls': urls,
'email_addresses': email_addresses,
'dates': dates,
'numeric_values': numbers,
'has_code': bool(re.search(r'```|def |class |import |function ', text)),
'has_questions': bool(re.search(r'\?', text))
}
def categorize_content(self, text: str) -> List[str]:
categories = []
category_keywords = {
'programming': ['code', 'function', 'class', 'variable', 'programming', 'software', 'debug'],
'data': ['data', 'database', 'query', 'table', 'record', 'statistics', 'analysis'],
'documentation': ['documentation', 'guide', 'tutorial', 'manual', 'readme', 'explain'],
'configuration': ['config', 'settings', 'configuration', 'setup', 'install', 'deployment'],
'testing': ['test', 'testing', 'validate', 'verification', 'quality', 'assertion'],
'research': ['research', 'study', 'analysis', 'investigation', 'findings', 'results'],
'planning': ['plan', 'planning', 'schedule', 'roadmap', 'milestone', 'timeline'],
}
text_lower = text.lower()
for category, keywords in category_keywords.items():
if any(keyword in text_lower for keyword in keywords):
categories.append(category)
return categories if categories else ['general']

View File

@ -0,0 +1,265 @@
import json
import sqlite3
import time
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
from .semantic_index import SemanticIndex
@dataclass
class KnowledgeEntry:
entry_id: str
category: str
content: str
metadata: Dict[str, Any]
created_at: float
updated_at: float
access_count: int = 0
importance_score: float = 1.0
def to_dict(self) -> Dict[str, Any]:
return {
'entry_id': self.entry_id,
'category': self.category,
'content': self.content,
'metadata': self.metadata,
'created_at': self.created_at,
'updated_at': self.updated_at,
'access_count': self.access_count,
'importance_score': self.importance_score
}
class KnowledgeStore:
def __init__(self, db_path: str):
self.db_path = db_path
self.semantic_index = SemanticIndex()
self._initialize_store()
self._load_index()
def _initialize_store(self):
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute('''
CREATE TABLE IF NOT EXISTS knowledge_entries (
entry_id TEXT PRIMARY KEY,
category TEXT NOT NULL,
content TEXT NOT NULL,
metadata TEXT,
created_at REAL NOT NULL,
updated_at REAL NOT NULL,
access_count INTEGER DEFAULT 0,
importance_score REAL DEFAULT 1.0
)
''')
cursor.execute('''
CREATE INDEX IF NOT EXISTS idx_category ON knowledge_entries(category)
''')
cursor.execute('''
CREATE INDEX IF NOT EXISTS idx_importance ON knowledge_entries(importance_score DESC)
''')
cursor.execute('''
CREATE INDEX IF NOT EXISTS idx_created ON knowledge_entries(created_at DESC)
''')
conn.commit()
conn.close()
def _load_index(self):
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute('SELECT entry_id, content FROM knowledge_entries')
for row in cursor.fetchall():
self.semantic_index.add_document(row[0], row[1])
conn.close()
def add_entry(self, entry: KnowledgeEntry):
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute('''
INSERT OR REPLACE INTO knowledge_entries
(entry_id, category, content, metadata, created_at, updated_at, access_count, importance_score)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
''', (
entry.entry_id,
entry.category,
entry.content,
json.dumps(entry.metadata),
entry.created_at,
entry.updated_at,
entry.access_count,
entry.importance_score
))
conn.commit()
conn.close()
self.semantic_index.add_document(entry.entry_id, entry.content)
def get_entry(self, entry_id: str) -> Optional[KnowledgeEntry]:
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute('''
SELECT entry_id, category, content, metadata, created_at, updated_at, access_count, importance_score
FROM knowledge_entries
WHERE entry_id = ?
''', (entry_id,))
row = cursor.fetchone()
if row:
cursor.execute('''
UPDATE knowledge_entries
SET access_count = access_count + 1
WHERE entry_id = ?
''', (entry_id,))
conn.commit()
conn.close()
return KnowledgeEntry(
entry_id=row[0],
category=row[1],
content=row[2],
metadata=json.loads(row[3]) if row[3] else {},
created_at=row[4],
updated_at=row[5],
access_count=row[6] + 1,
importance_score=row[7]
)
conn.close()
return None
def search_entries(self, query: str, category: Optional[str] = None,
top_k: int = 5) -> List[KnowledgeEntry]:
search_results = self.semantic_index.search(query, top_k * 2)
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
entries = []
for entry_id, score in search_results:
if category:
cursor.execute('''
SELECT entry_id, category, content, metadata, created_at, updated_at, access_count, importance_score
FROM knowledge_entries
WHERE entry_id = ? AND category = ?
''', (entry_id, category))
else:
cursor.execute('''
SELECT entry_id, category, content, metadata, created_at, updated_at, access_count, importance_score
FROM knowledge_entries
WHERE entry_id = ?
''', (entry_id,))
row = cursor.fetchone()
if row:
entry = KnowledgeEntry(
entry_id=row[0],
category=row[1],
content=row[2],
metadata=json.loads(row[3]) if row[3] else {},
created_at=row[4],
updated_at=row[5],
access_count=row[6],
importance_score=row[7]
)
entries.append(entry)
if len(entries) >= top_k:
break
conn.close()
return entries
def get_by_category(self, category: str, limit: int = 20) -> List[KnowledgeEntry]:
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute('''
SELECT entry_id, category, content, metadata, created_at, updated_at, access_count, importance_score
FROM knowledge_entries
WHERE category = ?
ORDER BY importance_score DESC, created_at DESC
LIMIT ?
''', (category, limit))
entries = []
for row in cursor.fetchall():
entries.append(KnowledgeEntry(
entry_id=row[0],
category=row[1],
content=row[2],
metadata=json.loads(row[3]) if row[3] else {},
created_at=row[4],
updated_at=row[5],
access_count=row[6],
importance_score=row[7]
))
conn.close()
return entries
def update_importance(self, entry_id: str, importance_score: float):
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute('''
UPDATE knowledge_entries
SET importance_score = ?, updated_at = ?
WHERE entry_id = ?
''', (importance_score, time.time(), entry_id))
conn.commit()
conn.close()
def delete_entry(self, entry_id: str) -> bool:
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute('DELETE FROM knowledge_entries WHERE entry_id = ?', (entry_id,))
deleted = cursor.rowcount > 0
conn.commit()
conn.close()
if deleted:
self.semantic_index.remove_document(entry_id)
return deleted
def get_statistics(self) -> Dict[str, Any]:
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute('SELECT COUNT(*) FROM knowledge_entries')
total_entries = cursor.fetchone()[0]
cursor.execute('SELECT COUNT(DISTINCT category) FROM knowledge_entries')
total_categories = cursor.fetchone()[0]
cursor.execute('''
SELECT category, COUNT(*) as count
FROM knowledge_entries
GROUP BY category
ORDER BY count DESC
''')
category_counts = {row[0]: row[1] for row in cursor.fetchall()}
cursor.execute('SELECT SUM(access_count) FROM knowledge_entries')
total_accesses = cursor.fetchone()[0] or 0
conn.close()
return {
'total_entries': total_entries,
'total_categories': total_categories,
'category_distribution': category_counts,
'total_accesses': total_accesses,
'vocabulary_size': len(self.semantic_index.vocabulary)
}

View File

@ -0,0 +1,87 @@
import math
import re
from collections import Counter, defaultdict
from typing import List, Dict, Tuple, Set
class SemanticIndex:
def __init__(self):
self.documents: Dict[str, str] = {}
self.vocabulary: Set[str] = set()
self.idf_scores: Dict[str, float] = {}
self.doc_vectors: Dict[str, Dict[str, float]] = {}
def _tokenize(self, text: str) -> List[str]:
text = text.lower()
text = re.sub(r'[^a-z0-9\s]', ' ', text)
tokens = text.split()
return tokens
def _compute_tf(self, tokens: List[str]) -> Dict[str, float]:
term_count = Counter(tokens)
total_terms = len(tokens)
return {term: count / total_terms for term, count in term_count.items()}
def _compute_idf(self):
doc_count = len(self.documents)
if doc_count == 0:
return
token_doc_count = defaultdict(int)
for doc_id, doc_text in self.documents.items():
tokens = set(self._tokenize(doc_text))
for token in tokens:
token_doc_count[token] += 1
if doc_count == 1:
self.idf_scores = {token: 1.0 for token in token_doc_count}
else:
self.idf_scores = {
token: math.log(doc_count / count)
for token, count in token_doc_count.items()
}
def add_document(self, doc_id: str, text: str):
self.documents[doc_id] = text
tokens = self._tokenize(text)
self.vocabulary.update(tokens)
self._compute_idf()
tf_scores = self._compute_tf(tokens)
self.doc_vectors[doc_id] = {
token: tf_scores.get(token, 0) * self.idf_scores.get(token, 0)
for token in tokens
}
def remove_document(self, doc_id: str):
if doc_id in self.documents:
del self.documents[doc_id]
if doc_id in self.doc_vectors:
del self.doc_vectors[doc_id]
self._compute_idf()
def search(self, query: str, top_k: int = 5) -> List[Tuple[str, float]]:
query_tokens = self._tokenize(query)
query_tf = self._compute_tf(query_tokens)
query_vector = {
token: query_tf.get(token, 0) * self.idf_scores.get(token, 0)
for token in query_tokens
}
scores = []
for doc_id, doc_vector in self.doc_vectors.items():
similarity = self._cosine_similarity(query_vector, doc_vector)
scores.append((doc_id, similarity))
scores.sort(key=lambda x: x[1], reverse=True)
return scores[:top_k]
def _cosine_similarity(self, vec1: Dict[str, float], vec2: Dict[str, float]) -> float:
dot_product = sum(vec1.get(token, 0) * vec2.get(token, 0) for token in set(vec1) | set(vec2))
norm1 = math.sqrt(sum(val**2 for val in vec1.values()))
norm2 = math.sqrt(sum(val**2 for val in vec2.values()))
if norm1 == 0 or norm2 == 0:
return 0
return dot_product / (norm1 * norm2)

98
pr/multiplexer.py Normal file
View File

@ -0,0 +1,98 @@
import threading
import queue
import time
import sys
from pr.ui import Colors
class TerminalMultiplexer:
def __init__(self, name, show_output=True):
self.name = name
self.show_output = show_output
self.stdout_buffer = []
self.stderr_buffer = []
self.stdout_queue = queue.Queue()
self.stderr_queue = queue.Queue()
self.active = True
self.lock = threading.Lock()
if self.show_output:
self.display_thread = threading.Thread(target=self._display_worker, daemon=True)
self.display_thread.start()
def _display_worker(self):
while self.active:
try:
line = self.stdout_queue.get(timeout=0.1)
if line:
sys.stdout.write(f"{Colors.GRAY}[{self.name}]{Colors.RESET} {line}")
sys.stdout.flush()
except queue.Empty:
pass
try:
line = self.stderr_queue.get(timeout=0.1)
if line:
sys.stderr.write(f"{Colors.YELLOW}[{self.name} err]{Colors.RESET} {line}")
sys.stderr.flush()
except queue.Empty:
pass
def write_stdout(self, data):
with self.lock:
self.stdout_buffer.append(data)
if self.show_output:
self.stdout_queue.put(data)
def write_stderr(self, data):
with self.lock:
self.stderr_buffer.append(data)
if self.show_output:
self.stderr_queue.put(data)
def get_stdout(self):
with self.lock:
return ''.join(self.stdout_buffer)
def get_stderr(self):
with self.lock:
return ''.join(self.stderr_buffer)
def get_all_output(self):
with self.lock:
return {
'stdout': ''.join(self.stdout_buffer),
'stderr': ''.join(self.stderr_buffer)
}
def close(self):
self.active = False
if hasattr(self, 'display_thread'):
self.display_thread.join(timeout=1)
_multiplexers = {}
_mux_counter = 0
_mux_lock = threading.Lock()
def create_multiplexer(name=None, show_output=True):
global _mux_counter
with _mux_lock:
if name is None:
_mux_counter += 1
name = f"process-{_mux_counter}"
mux = TerminalMultiplexer(name, show_output)
_multiplexers[name] = mux
return name, mux
def get_multiplexer(name):
return _multiplexers.get(name)
def close_multiplexer(name):
mux = _multiplexers.get(name)
if mux:
mux.close()
del _multiplexers[name]
def cleanup_all_multiplexers():
for mux in list(_multiplexers.values()):
mux.close()
_multiplexers.clear()

0
pr/plugins/__init__.py Normal file
View File

128
pr/plugins/loader.py Normal file
View File

@ -0,0 +1,128 @@
import os
import sys
import importlib.util
from typing import List, Dict, Callable, Any
from pr.core.logging import get_logger
logger = get_logger('plugins')
PLUGINS_DIR = os.path.expanduser("~/.pr/plugins")
class PluginLoader:
def __init__(self):
self.loaded_plugins = {}
self.plugin_tools = []
os.makedirs(PLUGINS_DIR, exist_ok=True)
def load_plugins(self) -> List[Dict]:
if not os.path.exists(PLUGINS_DIR):
logger.info("No plugins directory found")
return []
plugin_files = [f for f in os.listdir(PLUGINS_DIR) if f.endswith('.py')]
for plugin_file in plugin_files:
try:
self._load_plugin_file(plugin_file)
except Exception as e:
logger.error(f"Error loading plugin {plugin_file}: {e}")
return self.plugin_tools
def _load_plugin_file(self, filename: str):
plugin_path = os.path.join(PLUGINS_DIR, filename)
plugin_name = filename[:-3]
spec = importlib.util.spec_from_file_location(plugin_name, plugin_path)
if spec is None or spec.loader is None:
logger.error(f"Could not load spec for {filename}")
return
module = importlib.util.module_from_spec(spec)
sys.modules[plugin_name] = module
spec.loader.exec_module(module)
if hasattr(module, 'register_tools'):
tools = module.register_tools()
if isinstance(tools, list):
self.plugin_tools.extend(tools)
self.loaded_plugins[plugin_name] = module
logger.info(f"Loaded plugin: {plugin_name} ({len(tools)} tools)")
else:
logger.warning(f"Plugin {plugin_name} register_tools() did not return a list")
else:
logger.warning(f"Plugin {plugin_name} does not have register_tools() function")
def get_plugin_function(self, tool_name: str) -> Callable:
for plugin_name, module in self.loaded_plugins.items():
if hasattr(module, tool_name):
return getattr(module, tool_name)
raise ValueError(f"Plugin function not found: {tool_name}")
def list_loaded_plugins(self) -> List[str]:
return list(self.loaded_plugins.keys())
def create_example_plugin():
example_plugin = os.path.join(PLUGINS_DIR, 'example_plugin.py')
if os.path.exists(example_plugin):
return
example_code = '''"""
Example plugin for PR Assistant
This plugin demonstrates how to create custom tools.
"""
def my_custom_tool(argument: str) -> str:
"""
A custom tool that does something useful.
Args:
argument: Some input
Returns:
A result string
"""
return f"Custom tool processed: {argument}"
def register_tools():
"""
Register tools with the PR assistant.
Returns:
List of tool definitions
"""
return [
{
"type": "function",
"function": {
"name": "my_custom_tool",
"description": "A custom tool that processes input",
"parameters": {
"type": "object",
"properties": {
"argument": {
"type": "string",
"description": "The input to process"
}
},
"required": ["argument"]
}
}
}
]
'''
try:
os.makedirs(PLUGINS_DIR, exist_ok=True)
with open(example_plugin, 'w') as f:
f.write(example_code)
logger.info(f"Created example plugin at {example_plugin}")
except Exception as e:
logger.error(f"Error creating example plugin: {e}")

43
pr/research.md Normal file
View File

@ -0,0 +1,43 @@
# Research Overview: Additional Functionality for PR Assistant
## Overview of Current Application
The PR Assistant is a professional CLI AI assistant designed for autonomous execution of tasks. It integrates various tools including command execution, web fetching, database operations, filesystem management, and Python code execution. It features session management, logging, usage tracking, and a plugin system for extensibility.
## Potential New Features
Based on analysis of similar AI assistants and tool-using agents, here are researched ideas for additional functionality:
### 1. Multi-Modal Interfaces
- **Graphical User Interface (GUI)**: Develop a desktop app using frameworks like Electron or Tkinter to provide a user-friendly interface beyond CLI.
- **Web Interface**: Create a web-based dashboard for remote access and visualization of results.
- **Voice Input/Output**: Integrate speech recognition (e.g., via Google Speech API) and text-to-speech for hands-free interaction.
### 2. Enhanced Tool Ecosystem
- **Additional Built-in Tools**: Add tools for Git operations, email handling, calendar integration, image processing (e.g., OCR, generation via DALL-E), and real-time data feeds (weather, stocks).
- **API Integrations**: Connect to popular services like GitHub for repository management, Slack/Discord for notifications, or cloud storage (AWS S3, Google Drive).
- **Workflow Automation**: Implement chaining of tools for complex workflows, similar to Zapier or LangChain agents.
### 3. Advanced AI Capabilities
- **Multi-Agent Systems**: Allow multiple AI agents to collaborate on tasks, with role specialization (e.g., one for coding, one for research).
- **Long-Term Memory and Learning**: Implement persistent knowledge bases and fine-tuning on user interactions to improve responses over time.
- **Context Awareness**: Enhance context management with better summarization and retrieval of past conversations.
### 4. Productivity and Usability Enhancements
- **Export and Sharing**: Add options to export session results to formats like PDF, Markdown, or integrate with documentation tools (e.g., Notion, Confluence).
- **Scheduled Tasks**: Enable cron-like scheduling for autonomous task execution.
- **Multi-User Support**: Implement user accounts for shared access and collaboration features.
### 5. Security and Reliability
- **Sandboxing and Permissions**: Improve security with containerized tool execution and user-defined permission levels.
- **Error Recovery**: Add automatic retry mechanisms, fallback strategies, and detailed error reporting.
- **Audit Logging**: Enhance logging for compliance and debugging.
### 6. Plugin Ecosystem Expansion
- **Community Plugin Repository**: Create an online hub for user-contributed plugins.
- **Plugin Marketplace**: Allow users to rate and install plugins easily, with dependency management.
### 7. Performance Optimizations
- **Caching**: Implement caching for API calls and tool results to reduce latency.
- **Parallel Execution**: Enable concurrent tool usage for faster task completion.
- **Model Selection**: Expand support for multiple AI models and allow dynamic switching.
These features would position the PR Assistant as a more versatile and powerful tool, appealing to developers, researchers, and productivity enthusiasts. Implementation should prioritize backward compatibility and maintain the CLI-first approach while adding optional interfaces.

21
pr/tools/__init__.py Normal file
View File

@ -0,0 +1,21 @@
from pr.tools.base import get_tools_definition
from pr.tools.filesystem import (
read_file, write_file, list_directory, mkdir, chdir, getpwd, index_source_directory, search_replace
)
from pr.tools.command import run_command, run_command_interactive, tail_process, kill_process
from pr.tools.editor import open_editor, editor_insert_text, editor_replace_text, editor_search, close_editor
from pr.tools.database import db_set, db_get, db_query
from pr.tools.web import http_fetch, web_search, web_search_news
from pr.tools.python_exec import python_exec
from pr.tools.patch import apply_patch, create_diff
__all__ = [
'get_tools_definition',
'read_file', 'write_file', 'list_directory', 'mkdir', 'chdir', 'getpwd', 'index_source_directory', 'search_replace',
'open_editor', 'editor_insert_text', 'editor_replace_text', 'editor_search','close_editor',
'run_command', 'run_command_interactive',
'db_set', 'db_get', 'db_query',
'http_fetch', 'web_search', 'web_search_news',
'python_exec','tail_process', 'kill_process',
'apply_patch', 'create_diff'
]

444
pr/tools/base.py Normal file
View File

@ -0,0 +1,444 @@
def get_tools_definition():
return [
{
"type": "function",
"function": {
"name": "kill_process",
"description": "Terminate a background process by its PID. Use this to stop processes started with run_command that exceeded their timeout.",
"parameters": {
"type": "object",
"properties": {
"pid": {
"type": "integer",
"description": "The process ID returned by run_command when status is 'running'."
}
},
"required": ["pid"]
}
}
},
{
"type": "function",
"function": {
"name": "tail_process",
"description": "Monitor and retrieve output from a background process by its PID. Use this to check on processes started with run_command that exceeded their timeout.",
"parameters": {
"type": "object",
"properties": {
"pid": {
"type": "integer",
"description": "The process ID returned by run_command when status is 'running'."
},
"timeout": {
"type": "integer",
"description": "Maximum seconds to wait for process completion. Returns partial output if still running.",
"default": 30
}
},
"required": ["pid"]
}
}
},
{
"type": "function",
"function": {
"name": "http_fetch",
"description": "Fetch content from an HTTP URL",
"parameters": {
"type": "object",
"properties": {
"url": {"type": "string", "description": "The URL to fetch"},
"headers": {"type": "object", "description": "Optional HTTP headers"}
},
"required": ["url"]
}
}
},
{
"type": "function",
"function": {
"name": "run_command",
"description": "Execute a shell command and capture output. Returns immediately after timeout with PID if still running. Use tail_process to monitor or kill_process to terminate long-running commands.",
"parameters": {
"type": "object",
"properties": {
"command": {"type": "string", "description": "The shell command to execute"},
"timeout": {"type": "integer", "description": "Maximum seconds to wait for completion", "default": 30}
},
"required": ["command"]
}
}
},
{
"type": "function",
"function": {
"name": "run_command_interactive",
"description": "Execute an interactive terminal command that requires user input or displays UI. The command runs in the user's terminal. Returns exit code only.",
"parameters": {
"type": "object",
"properties": {
"command": {"type": "string", "description": "The interactive command to execute (e.g., vim, nano, top)"}
},
"required": ["command"]
}
}
},
{
"type": "function",
"function": {
"name": "read_file",
"description": "Read contents of a file",
"parameters": {
"type": "object",
"properties": {
"filepath": {"type": "string", "description": "Path to the file"}
},
"required": ["filepath"]
}
}
},
{
"type": "function",
"function": {
"name": "write_file",
"description": "Write content to a file",
"parameters": {
"type": "object",
"properties": {
"filepath": {"type": "string", "description": "Path to the file"},
"content": {"type": "string", "description": "Content to write"}
},
"required": ["filepath", "content"]
}
}
},
{
"type": "function",
"function": {
"name": "list_directory",
"description": "List directory contents",
"parameters": {
"type": "object",
"properties": {
"path": {"type": "string", "description": "Directory path", "default": "."},
"recursive": {"type": "boolean", "description": "List recursively", "default": False}
}
}
}
},
{
"type": "function",
"function": {
"name": "mkdir",
"description": "Create a new directory",
"parameters": {
"type": "object",
"properties": {
"path": {"type": "string", "description": "Path of the directory to create"}
},
"required": ["path"]
}
}
},
{
"type": "function",
"function": {
"name": "chdir",
"description": "Change the current working directory",
"parameters": {
"type": "object",
"properties": {
"path": {"type": "string", "description": "Path to change to"}
},
"required": ["path"]
}
}
},
{
"type": "function",
"function": {
"name": "getpwd",
"description": "Get the current working directory",
"parameters": {"type": "object", "properties": {}}
}
},
{
"type": "function",
"function": {
"name": "db_set",
"description": "Set a key-value pair in the database",
"parameters": {
"type": "object",
"properties": {
"key": {"type": "string", "description": "The key"},
"value": {"type": "string", "description": "The value"}
},
"required": ["key", "value"]
}
}
},
{
"type": "function",
"function": {
"name": "db_get",
"description": "Get a value from the database",
"parameters": {
"type": "object",
"properties": {
"key": {"type": "string", "description": "The key"}
},
"required": ["key"]
}
}
},
{
"type": "function",
"function": {
"name": "db_query",
"description": "Execute a database query",
"parameters": {
"type": "object",
"properties": {
"query": {"type": "string", "description": "SQL query"}
},
"required": ["query"]
}
}
},
{
"type": "function",
"function": {
"name": "web_search",
"description": "Perform a web search",
"parameters": {
"type": "object",
"properties": {
"query": {"type": "string", "description": "Search query"}
},
"required": ["query"]
}
}
},
{
"type": "function",
"function": {
"name": "web_search_news",
"description": "Perform a web search for news",
"parameters": {
"type": "object",
"properties": {
"query": {"type": "string", "description": "Search query for news"}
},
"required": ["query"]
}
}
},
{
"type": "function",
"function": {
"name": "python_exec",
"description": "Execute Python code",
"parameters": {
"type": "object",
"properties": {
"code": {"type": "string", "description": "Python code to execute"}
},
"required": ["code"]
}
}
},
{
"type": "function",
"function": {
"name": "index_source_directory",
"description": "Index directory recursively and read all source files.",
"parameters": {
"type": "object",
"properties": {
"path": {"type": "string", "description": "Path to index"}
},
"required": ["path"]
}
}
},
{
"type": "function",
"function": {
"name": "search_replace",
"description": "Search and replace text in a file",
"parameters": {
"type": "object",
"properties": {
"filepath": {"type": "string", "description": "Path to the file"},
"old_string": {"type": "string", "description": "String to replace"},
"new_string": {"type": "string", "description": "Replacement string"}
},
"required": ["filepath", "old_string", "new_string"]
}
}
},
{
"type": "function",
"function": {
"name": "apply_patch",
"description": "Apply a patch to a file, especially for source code",
"parameters": {
"type": "object",
"properties": {
"filepath": {"type": "string", "description": "Path to the file to patch"},
"patch_content": {"type": "string", "description": "The patch content as a string"}
},
"required": ["filepath", "patch_content"]
}
}
},
{
"type": "function",
"function": {
"name": "create_diff",
"description": "Create a unified diff between two files",
"parameters": {
"type": "object",
"properties": {
"file1": {"type": "string", "description": "Path to the first file"},
"file2": {"type": "string", "description": "Path to the second file"},
"fromfile": {"type": "string", "description": "Label for the first file", "default": "file1"},
"tofile": {"type": "string", "description": "Label for the second file", "default": "file2"}
},
"required": ["file1", "file2"]
}
}
},
{
"type": "function",
"function": {
"name": "open_editor",
"description": "Open the RPEditor for a file",
"parameters": {
"type": "object",
"properties": {
"filepath": {"type": "string", "description": "Path to the file"}
},
"required": ["filepath"]
}
}
},
{
"type": "function",
"function": {
"name": "close_editor",
"description": "Close the RPEditor. Always close files when finished editing.",
"parameters": {
"type": "object",
"properties": {
"filepath": {"type": "string", "description": "Path to the file"}
},
"required": ["filepath"]
}
}
},
{
"type": "function",
"function": {
"name": "editor_insert_text",
"description": "Insert text at cursor position in the editor",
"parameters": {
"type": "object",
"properties": {
"filepath": {"type": "string", "description": "Path to the file"},
"text": {"type": "string", "description": "Text to insert"},
"line": {"type": "integer", "description": "Line number (optional)"},
"col": {"type": "integer", "description": "Column number (optional)"}
},
"required": ["filepath", "text"]
}
}
},
{
"type": "function",
"function": {
"name": "editor_replace_text",
"description": "Replace text in a range",
"parameters": {
"type": "object",
"properties": {
"filepath": {"type": "string", "description": "Path to the file"},
"start_line": {"type": "integer", "description": "Start line"},
"start_col": {"type": "integer", "description": "Start column"},
"end_line": {"type": "integer", "description": "End line"},
"end_col": {"type": "integer", "description": "End column"},
"new_text": {"type": "string", "description": "New text"}
},
"required": ["filepath", "start_line", "start_col", "end_line", "end_col", "new_text"]
}
}
},
{
"type": "function",
"function": {
"name": "editor_search",
"description": "Search for a pattern in the file",
"parameters": {
"type": "object",
"properties": {
"filepath": {"type": "string", "description": "Path to the file"},
"pattern": {"type": "string", "description": "Regex pattern"},
"start_line": {"type": "integer", "description": "Start line", "default": 0}
},
"required": ["filepath", "pattern"]
}
}
},
{
"type": "function",
"function": {
"name": "display_file_diff",
"description": "Display a visual colored diff between two files with syntax highlighting and statistics",
"parameters": {
"type": "object",
"properties": {
"filepath1": {"type": "string", "description": "Path to the original file"},
"filepath2": {"type": "string", "description": "Path to the modified file"},
"format_type": {"type": "string", "description": "Display format: 'unified' or 'side-by-side'", "default": "unified"}
},
"required": ["filepath1", "filepath2"]
}
}
},
{
"type": "function",
"function": {
"name": "display_edit_summary",
"description": "Display a summary of all edit operations performed during the session",
"parameters": {
"type": "object",
"properties": {}
}
}
},
{
"type": "function",
"function": {
"name": "display_edit_timeline",
"description": "Display a timeline of all edit operations with details",
"parameters": {
"type": "object",
"properties": {
"show_content": {"type": "boolean", "description": "Show content previews", "default": False}
}
}
}
},
{
"type": "function",
"function": {
"name": "clear_edit_tracker",
"description": "Clear the edit tracker to start fresh",
"parameters": {
"type": "object",
"properties": {}
}
}
}
]

164
pr/tools/command.py Normal file
View File

@ -0,0 +1,164 @@
import os
import subprocess
import time
import select
from pr.multiplexer import create_multiplexer, close_multiplexer, get_multiplexer
_processes = {}
def _register_process(pid:int, process):
_processes[pid] = process
return _processes
def _get_process(pid:int):
return _processes.get(pid)
def kill_process(pid:int):
try:
process = _get_process(pid)
if process:
process.kill()
_processes.pop(pid)
mux_name = f"cmd-{pid}"
if get_multiplexer(mux_name):
close_multiplexer(mux_name)
return {"status": "success", "message": f"Process {pid} has been killed"}
else:
return {"status": "error", "error": f"Process {pid} not found"}
except Exception as e:
return {"status": "error", "error": str(e)}
def tail_process(pid: int, timeout: int = 30):
process = _get_process(pid)
if process:
mux_name = f"cmd-{pid}"
mux = get_multiplexer(mux_name)
if not mux:
mux_name, mux = create_multiplexer(mux_name, show_output=True)
try:
start_time = time.time()
timeout_duration = timeout
stdout_content = ""
stderr_content = ""
while True:
if process.poll() is not None:
remaining_stdout, remaining_stderr = process.communicate()
if remaining_stdout:
mux.write_stdout(remaining_stdout)
stdout_content += remaining_stdout
if remaining_stderr:
mux.write_stderr(remaining_stderr)
stderr_content += remaining_stderr
if pid in _processes:
_processes.pop(pid)
close_multiplexer(mux_name)
return {
"status": "success",
"stdout": stdout_content,
"stderr": stderr_content,
"returncode": process.returncode
}
if time.time() - start_time > timeout_duration:
return {
"status": "running",
"message": "Process is still running. Call tail_process again to continue monitoring.",
"stdout_so_far": stdout_content,
"stderr_so_far": stderr_content,
"pid": pid
}
ready, _, _ = select.select([process.stdout, process.stderr], [], [], 0.1)
for pipe in ready:
if pipe == process.stdout:
line = process.stdout.readline()
if line:
mux.write_stdout(line)
stdout_content += line
elif pipe == process.stderr:
line = process.stderr.readline()
if line:
mux.write_stderr(line)
stderr_content += line
except Exception as e:
return {"status": "error", "error": str(e)}
else:
return {"status": "error", "error": f"Process {pid} not found"}
def run_command(command, timeout=30):
mux_name = None
try:
process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
_register_process(process.pid, process)
mux_name, mux = create_multiplexer(f"cmd-{process.pid}", show_output=True)
start_time = time.time()
timeout_duration = timeout
stdout_content = ""
stderr_content = ""
while True:
if process.poll() is not None:
remaining_stdout, remaining_stderr = process.communicate()
if remaining_stdout:
mux.write_stdout(remaining_stdout)
stdout_content += remaining_stdout
if remaining_stderr:
mux.write_stderr(remaining_stderr)
stderr_content += remaining_stderr
if process.pid in _processes:
_processes.pop(process.pid)
close_multiplexer(mux_name)
return {
"status": "success",
"stdout": stdout_content,
"stderr": stderr_content,
"returncode": process.returncode
}
if time.time() - start_time > timeout_duration:
return {
"status": "running",
"message": f"Process still running after {timeout}s timeout. Use tail_process({process.pid}) to monitor or kill_process({process.pid}) to terminate.",
"stdout_so_far": stdout_content,
"stderr_so_far": stderr_content,
"pid": process.pid,
"mux_name": mux_name
}
ready, _, _ = select.select([process.stdout, process.stderr], [], [], 0.1)
for pipe in ready:
if pipe == process.stdout:
line = process.stdout.readline()
if line:
mux.write_stdout(line)
stdout_content += line
elif pipe == process.stderr:
line = process.stderr.readline()
if line:
mux.write_stderr(line)
stderr_content += line
except Exception as e:
if mux_name:
close_multiplexer(mux_name)
return {"status": "error", "error": str(e)}
def run_command_interactive(command):
try:
return_code = os.system(command)
return {"status": "success", "returncode": return_code}
except Exception as e:
return {"status": "error", "error": str(e)}

47
pr/tools/database.py Normal file
View File

@ -0,0 +1,47 @@
import time
def db_set(key, value, db_conn):
if not db_conn:
return {"status": "error", "error": "Database not initialized"}
try:
cursor = db_conn.cursor()
cursor.execute("""INSERT OR REPLACE INTO kv_store (key, value, timestamp)
VALUES (?, ?, ?)""", (key, value, time.time()))
db_conn.commit()
return {"status": "success", "message": f"Set {key}"}
except Exception as e:
return {"status": "error", "error": str(e)}
def db_get(key, db_conn):
if not db_conn:
return {"status": "error", "error": "Database not initialized"}
try:
cursor = db_conn.cursor()
cursor.execute("SELECT value FROM kv_store WHERE key = ?", (key,))
result = cursor.fetchone()
if result:
return {"status": "success", "value": result[0]}
else:
return {"status": "error", "error": "Key not found"}
except Exception as e:
return {"status": "error", "error": str(e)}
def db_query(query, db_conn):
if not db_conn:
return {"status": "error", "error": "Database not initialized"}
try:
cursor = db_conn.cursor()
cursor.execute(query)
if query.strip().upper().startswith('SELECT'):
results = cursor.fetchall()
columns = [desc[0] for desc in cursor.description] if cursor.description else []
return {"status": "success", "columns": columns, "rows": results}
else:
db_conn.commit()
return {"status": "success", "rows_affected": cursor.rowcount}
except Exception as e:
return {"status": "error", "error": str(e)}

144
pr/tools/editor.py Normal file
View File

@ -0,0 +1,144 @@
from pr.editor import RPEditor
from pr.multiplexer import create_multiplexer, close_multiplexer, get_multiplexer
from ..ui.diff_display import display_diff, get_diff_stats
from ..ui.edit_feedback import track_edit, tracker
from ..tools.patch import display_content_diff
import os
import os.path
_editors = {}
def get_editor(filepath):
if filepath not in _editors:
_editors[filepath] = RPEditor(filepath)
return _editors[filepath]
def close_editor(filepath):
try:
path = os.path.expanduser(filepath)
editor = get_editor(path)
editor.close()
mux_name = f"editor-{path}"
mux = get_multiplexer(mux_name)
if mux:
mux.write_stdout(f"Closed editor for: {path}\n")
close_multiplexer(mux_name)
return {"status": "success", "message": f"Editor closed for {path}"}
except Exception as e:
return {"status": "error", "error": str(e)}
def open_editor(filepath):
try:
path = os.path.expanduser(filepath)
editor = RPEditor(path)
editor.start()
mux_name = f"editor-{path}"
mux_name, mux = create_multiplexer(mux_name, show_output=True)
mux.write_stdout(f"Opened editor for: {path}\n")
return {"status": "success", "message": f"Editor opened for {path}", "mux_name": mux_name}
except Exception as e:
return {"status": "error", "error": str(e)}
def editor_insert_text(filepath, text, line=None, col=None, show_diff=True):
try:
path = os.path.expanduser(filepath)
old_content = ""
if os.path.exists(path):
with open(path, 'r') as f:
old_content = f.read()
position = (line if line is not None else 0) * 1000 + (col if col is not None else 0)
operation = track_edit('INSERT', filepath, start_pos=position, content=text)
tracker.mark_in_progress(operation)
editor = get_editor(path)
if line is not None and col is not None:
editor.move_cursor_to(line, col)
editor.insert_text(text)
editor.save_file()
mux_name = f"editor-{path}"
mux = get_multiplexer(mux_name)
if mux:
location = f" at line {line}, col {col}" if line is not None and col is not None else ""
preview = text[:50] + "..." if len(text) > 50 else text
mux.write_stdout(f"Inserted text{location}: {repr(preview)}\n")
if show_diff and old_content:
with open(path, 'r') as f:
new_content = f.read()
diff_result = display_content_diff(old_content, new_content, filepath)
if diff_result["status"] == "success":
mux.write_stdout(diff_result["visual_diff"] + "\n")
tracker.mark_completed(operation)
result = {"status": "success", "message": f"Inserted text in {path}"}
close_editor(filepath)
return result
except Exception as e:
if 'operation' in locals():
tracker.mark_failed(operation)
return {"status": "error", "error": str(e)}
def editor_replace_text(filepath, start_line, start_col, end_line, end_col, new_text, show_diff=True):
try:
path = os.path.expanduser(filepath)
old_content = ""
if os.path.exists(path):
with open(path, 'r') as f:
old_content = f.read()
start_pos = start_line * 1000 + start_col
end_pos = end_line * 1000 + end_col
operation = track_edit('REPLACE', filepath, start_pos=start_pos, end_pos=end_pos,
content=new_text, old_content=old_content)
tracker.mark_in_progress(operation)
editor = get_editor(path)
editor.replace_text(start_line, start_col, end_line, end_col, new_text)
editor.save_file()
mux_name = f"editor-{path}"
mux = get_multiplexer(mux_name)
if mux:
preview = new_text[:50] + "..." if len(new_text) > 50 else new_text
mux.write_stdout(f"Replaced text from ({start_line},{start_col}) to ({end_line},{end_col}): {repr(preview)}\n")
if show_diff and old_content:
with open(path, 'r') as f:
new_content = f.read()
diff_result = display_content_diff(old_content, new_content, filepath)
if diff_result["status"] == "success":
mux.write_stdout(diff_result["visual_diff"] + "\n")
tracker.mark_completed(operation)
result = {"status": "success", "message": f"Replaced text in {path}"}
close_editor(filepath)
return result
except Exception as e:
if 'operation' in locals():
tracker.mark_failed(operation)
return {"status": "error", "error": str(e)}
def editor_search(filepath, pattern, start_line=0):
try:
path = os.path.expanduser(filepath)
editor = RPEditor(path)
results = editor.search(pattern, start_line)
mux_name = f"editor-{path}"
mux = get_multiplexer(mux_name)
if mux:
mux.write_stdout(f"Searched for pattern '{pattern}' from line {start_line}: {len(results)} matches\n")
result = {"status": "success", "results": results}
close_editor(filepath)
return result
except Exception as e:
return {"status": "error", "error": str(e)}

287
pr/tools/filesystem.py Normal file
View File

@ -0,0 +1,287 @@
import os
import hashlib
import time
from typing import Dict
from pr.editor import RPEditor
from ..ui.diff_display import display_diff, get_diff_stats
from ..ui.edit_feedback import track_edit, tracker
from ..tools.patch import display_content_diff
_id = 0
def get_uid():
global _id
_id += 3
return _id
def read_file(filepath, db_conn=None):
try:
path = os.path.expanduser(filepath)
with open(path, 'r') as f:
content = f.read()
if db_conn:
from pr.tools.database import db_set
db_set("read:" + path, "true", db_conn)
return {"status": "success", "content": content}
except Exception as e:
return {"status": "error", "error": str(e)}
def write_file(filepath, content, db_conn=None, show_diff=True):
try:
path = os.path.expanduser(filepath)
old_content = ""
is_new_file = not os.path.exists(path)
if not is_new_file and db_conn:
from pr.tools.database import db_get
read_status = db_get("read:" + path, db_conn)
if read_status.get("status") != "success" or read_status.get("value") != "true":
return {"status": "error", "error": "File must be read before writing. Please read the file first."}
if not is_new_file:
with open(path, 'r') as f:
old_content = f.read()
operation = track_edit('WRITE', filepath, content=content, old_content=old_content)
tracker.mark_in_progress(operation)
if show_diff and not is_new_file:
diff_result = display_content_diff(old_content, content, filepath)
if diff_result["status"] == "success":
print(diff_result["visual_diff"])
editor = RPEditor(path)
editor.set_text(content)
editor.save_file()
if os.path.exists(path) and db_conn:
try:
cursor = db_conn.cursor()
file_hash = hashlib.md5(old_content.encode()).hexdigest()
cursor.execute("SELECT MAX(version) FROM file_versions WHERE filepath = ?", (filepath,))
result = cursor.fetchone()
version = (result[0] + 1) if result[0] else 1
cursor.execute("""INSERT INTO file_versions (filepath, content, hash, timestamp, version)
VALUES (?, ?, ?, ?, ?)""",
(filepath, old_content, file_hash, time.time(), version))
db_conn.commit()
except Exception:
pass
tracker.mark_completed(operation)
message = f"File written to {path}"
if show_diff and not is_new_file:
stats = get_diff_stats(old_content, content)
message += f" ({stats['insertions']}+ {stats['deletions']}-)"
return {"status": "success", "message": message}
except Exception as e:
if 'operation' in locals():
tracker.mark_failed(operation)
return {"status": "error", "error": str(e)}
def list_directory(path=".", recursive=False):
try:
path = os.path.expanduser(path)
items = []
if recursive:
for root, dirs, files in os.walk(path):
for name in files:
item_path = os.path.join(root, name)
items.append({"path": item_path, "type": "file", "size": os.path.getsize(item_path)})
for name in dirs:
items.append({"path": os.path.join(root, name), "type": "directory"})
else:
for item in os.listdir(path):
item_path = os.path.join(path, item)
items.append({
"name": item,
"type": "directory" if os.path.isdir(item_path) else "file",
"size": os.path.getsize(item_path) if os.path.isfile(item_path) else None
})
return {"status": "success", "items": items}
except Exception as e:
return {"status": "error", "error": str(e)}
def mkdir(path):
try:
os.makedirs(os.path.expanduser(path), exist_ok=True)
return {"status": "success", "message": f"Directory created at {path}"}
except Exception as e:
return {"status": "error", "error": str(e)}
def chdir(path):
try:
os.chdir(os.path.expanduser(path))
return {"status": "success", "new_path": os.getcwd()}
except Exception as e:
return {"status": "error", "error": str(e)}
def getpwd():
try:
return {"status": "success", "path": os.getcwd()}
except Exception as e:
return {"status": "error", "error": str(e)}
def index_source_directory(path):
extensions = [
".py", ".js", ".ts", ".java", ".cpp", ".c", ".h", ".hpp",
".html", ".css", ".json", ".xml", ".md", ".sh", ".rb", ".go"
]
source_files = []
try:
for root, _, files in os.walk(os.path.expanduser(path)):
for file in files:
if any(file.endswith(ext) for ext in extensions):
filepath = os.path.join(root, file)
try:
with open(filepath, 'r', encoding='utf-8') as f:
content = f.read()
source_files.append({
"path": filepath,
"content": content
})
except Exception:
continue
return {"status": "success", "indexed_files": source_files}
except Exception as e:
return {"status": "error", "error": str(e)}
def search_replace(filepath, old_string, new_string, db_conn=None):
try:
path = os.path.expanduser(filepath)
if not os.path.exists(path):
return {"status": "error", "error": "File does not exist"}
if db_conn:
from pr.tools.database import db_get
read_status = db_get("read:" + path, db_conn)
if read_status.get("status") != "success" or read_status.get("value") != "true":
return {"status": "error", "error": "File must be read before writing. Please read the file first."}
with open(path, 'r') as f:
content = f.read()
content = content.replace(old_string, new_string)
with open(path, 'w') as f:
f.write(content)
return {"status": "success", "message": f"Replaced '{old_string}' with '{new_string}' in {path}"}
except Exception as e:
return {"status": "error", "error": str(e)}
_editors = {}
def get_editor(filepath):
if filepath not in _editors:
_editors[filepath] = RPEditor(filepath)
return _editors[filepath]
def close_editor(filepath):
try:
path = os.path.expanduser(filepath)
editor = get_editor(path)
editor.close()
return {"status": "success", "message": f"Editor closed for {path}"}
except Exception as e:
return {"status": "error", "error": str(e)}
def open_editor(filepath):
try:
path = os.path.expanduser(filepath)
editor = RPEditor(path)
editor.start()
return {"status": "success", "message": f"Editor opened for {path}"}
except Exception as e:
return {"status": "error", "error": str(e)}
def editor_insert_text(filepath, text, line=None, col=None, show_diff=True, db_conn=None):
try:
path = os.path.expanduser(filepath)
if db_conn:
from pr.tools.database import db_get
read_status = db_get("read:" + path, db_conn)
if read_status.get("status") != "success" or read_status.get("value") != "true":
return {"status": "error", "error": "File must be read before writing. Please read the file first."}
old_content = ""
if os.path.exists(path):
with open(path, 'r') as f:
old_content = f.read()
position = (line if line is not None else 0) * 1000 + (col if col is not None else 0)
operation = track_edit('INSERT', filepath, start_pos=position, content=text)
tracker.mark_in_progress(operation)
editor = get_editor(path)
if line is not None and col is not None:
editor.move_cursor_to(line, col)
editor.insert_text(text)
editor.save_file()
if show_diff and old_content:
with open(path, 'r') as f:
new_content = f.read()
diff_result = display_content_diff(old_content, new_content, filepath)
if diff_result["status"] == "success":
print(diff_result["visual_diff"])
tracker.mark_completed(operation)
return {"status": "success", "message": f"Inserted text in {path}"}
except Exception as e:
if 'operation' in locals():
tracker.mark_failed(operation)
return {"status": "error", "error": str(e)}
def editor_replace_text(filepath, start_line, start_col, end_line, end_col, new_text, show_diff=True, db_conn=None):
try:
path = os.path.expanduser(filepath)
if db_conn:
from pr.tools.database import db_get
read_status = db_get("read:" + path, db_conn)
if read_status.get("status") != "success" or read_status.get("value") != "true":
return {"status": "error", "error": "File must be read before writing. Please read the file first."}
old_content = ""
if os.path.exists(path):
with open(path, 'r') as f:
old_content = f.read()
start_pos = start_line * 1000 + start_col
end_pos = end_line * 1000 + end_col
operation = track_edit('REPLACE', filepath, start_pos=start_pos, end_pos=end_pos,
content=new_text, old_content=old_content)
tracker.mark_in_progress(operation)
editor = get_editor(path)
editor.replace_text(start_line, start_col, end_line, end_col, new_text)
editor.save_file()
if show_diff and old_content:
with open(path, 'r') as f:
new_content = f.read()
diff_result = display_content_diff(old_content, new_content, filepath)
if diff_result["status"] == "success":
print(diff_result["visual_diff"])
tracker.mark_completed(operation)
return {"status": "success", "message": f"Replaced text in {path}"}
except Exception as e:
if 'operation' in locals():
tracker.mark_failed(operation)
return {"status": "error", "error": str(e)}
def display_edit_summary():
from ..ui.edit_feedback import display_edit_summary
return display_edit_summary()
def display_edit_timeline(show_content=False):
from ..ui.edit_feedback import display_edit_timeline
return display_edit_timeline(show_content)
def clear_edit_tracker():
from ..ui.edit_feedback import clear_tracker
clear_tracker()
return {"status": "success", "message": "Edit tracker cleared"}

91
pr/tools/patch.py Normal file
View File

@ -0,0 +1,91 @@
import os
import tempfile
import subprocess
import difflib
from ..ui.diff_display import display_diff, get_diff_stats, DiffDisplay
def apply_patch(filepath, patch_content, db_conn=None):
try:
path = os.path.expanduser(filepath)
if db_conn:
from pr.tools.database import db_get
read_status = db_get("read:" + path, db_conn)
if read_status.get("status") != "success" or read_status.get("value") != "true":
return {"status": "error", "error": "File must be read before writing. Please read the file first."}
# Write patch to temp file
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.patch') as f:
f.write(patch_content)
patch_file = f.name
# Run patch command
result = subprocess.run(['patch', path, patch_file], capture_output=True, text=True, cwd=os.path.dirname(path))
os.unlink(patch_file)
if result.returncode == 0:
return {"status": "success", "output": result.stdout.strip()}
else:
return {"status": "error", "error": result.stderr.strip()}
except Exception as e:
return {"status": "error", "error": str(e)}
def create_diff(file1, file2, fromfile='file1', tofile='file2', visual=False, format_type='unified'):
try:
path1 = os.path.expanduser(file1)
path2 = os.path.expanduser(file2)
with open(path1, 'r') as f1, open(path2, 'r') as f2:
content1 = f1.read()
content2 = f2.read()
if visual:
visual_diff = display_diff(content1, content2, fromfile, format_type)
stats = get_diff_stats(content1, content2)
lines1 = content1.splitlines(keepends=True)
lines2 = content2.splitlines(keepends=True)
plain_diff = list(difflib.unified_diff(lines1, lines2, fromfile=fromfile, tofile=tofile))
return {
"status": "success",
"diff": ''.join(plain_diff),
"visual_diff": visual_diff,
"stats": stats
}
else:
lines1 = content1.splitlines(keepends=True)
lines2 = content2.splitlines(keepends=True)
diff = list(difflib.unified_diff(lines1, lines2, fromfile=fromfile, tofile=tofile))
return {"status": "success", "diff": ''.join(diff)}
except Exception as e:
return {"status": "error", "error": str(e)}
def display_file_diff(filepath1, filepath2, format_type='unified', context_lines=3):
try:
path1 = os.path.expanduser(filepath1)
path2 = os.path.expanduser(filepath2)
with open(path1, 'r') as f1:
old_content = f1.read()
with open(path2, 'r') as f2:
new_content = f2.read()
visual_diff = display_diff(old_content, new_content, filepath1, format_type)
stats = get_diff_stats(old_content, new_content)
return {
"status": "success",
"visual_diff": visual_diff,
"stats": stats
}
except Exception as e:
return {"status": "error", "error": str(e)}
def display_content_diff(old_content, new_content, filename='file', format_type='unified'):
try:
visual_diff = display_diff(old_content, new_content, filename, format_type)
stats = get_diff_stats(old_content, new_content)
return {
"status": "success",
"visual_diff": visual_diff,
"stats": stats
}
except Exception as e:
return {"status": "error", "error": str(e)}

13
pr/tools/python_exec.py Normal file
View File

@ -0,0 +1,13 @@
import traceback
from io import StringIO
import contextlib
def python_exec(code, python_globals):
try:
output = StringIO()
with contextlib.redirect_stdout(output):
exec(code, python_globals)
return {"status": "success", "output": output.getvalue()}
except Exception as e:
return {"status": "error", "error": str(e), "traceback": traceback.format_exc()}

36
pr/tools/web.py Normal file
View File

@ -0,0 +1,36 @@
import urllib.request
import urllib.parse
import urllib.error
import json
def http_fetch(url, headers=None):
try:
req = urllib.request.Request(url)
if headers:
for key, value in headers.items():
req.add_header(key, value)
with urllib.request.urlopen(req) as response:
content = response.read().decode('utf-8')
return {"status": "success", "content": content[:10000]}
except Exception as e:
return {"status": "error", "error": str(e)}
def _perform_search(base_url, query, params=None):
try:
full_url = f"https://static.molodetz.nl/search.cgi?query={query}"
with urllib.request.urlopen(full_url) as response:
content = response.read().decode('utf-8')
return {"status": "success", "content": json.loads(content)}
except Exception as e:
return {"status": "error", "error": str(e)}
def web_search(query):
base_url = "https://search.molodetz.nl/search"
return _perform_search(base_url, query)
def web_search_news(query):
base_url = "https://search.molodetz.nl/search"
return _perform_search(base_url, query)

5
pr/ui/__init__.py Normal file
View File

@ -0,0 +1,5 @@
from pr.ui.colors import Colors
from pr.ui.rendering import highlight_code, render_markdown
from pr.ui.display import display_tool_call, print_autonomous_header
__all__ = ['Colors', 'highlight_code', 'render_markdown', 'display_tool_call', 'print_autonomous_header']

14
pr/ui/colors.py Normal file
View File

@ -0,0 +1,14 @@
class Colors:
RESET = '\033[0m'
BOLD = '\033[1m'
RED = '\033[91m'
GREEN = '\033[92m'
YELLOW = '\033[93m'
BLUE = '\033[94m'
MAGENTA = '\033[95m'
CYAN = '\033[96m'
GRAY = '\033[90m'
WHITE = '\033[97m'
BG_BLUE = '\033[44m'
BG_GREEN = '\033[42m'
BG_RED = '\033[41m'

199
pr/ui/diff_display.py Normal file
View File

@ -0,0 +1,199 @@
import difflib
from typing import List, Tuple, Dict, Optional
from .colors import Colors
class DiffStats:
def __init__(self):
self.insertions = 0
self.deletions = 0
self.modifications = 0
self.files_changed = 0
@property
def total_changes(self):
return self.insertions + self.deletions
def __str__(self):
return f"{self.files_changed} file(s) changed, {self.insertions} insertions(+), {self.deletions} deletions(-)"
class DiffLine:
def __init__(self, line_type: str, content: str, old_line_num: Optional[int] = None,
new_line_num: Optional[int] = None):
self.line_type = line_type
self.content = content
self.old_line_num = old_line_num
self.new_line_num = new_line_num
def format(self, show_line_nums: bool = True) -> str:
color = {
'add': Colors.GREEN,
'delete': Colors.RED,
'context': Colors.GRAY,
'header': Colors.CYAN,
'stats': Colors.BLUE
}.get(self.line_type, Colors.RESET)
prefix = {
'add': '+ ',
'delete': '- ',
'context': ' ',
'header': '',
'stats': ''
}.get(self.line_type, ' ')
if show_line_nums and self.line_type in ('add', 'delete', 'context'):
old_num = str(self.old_line_num) if self.old_line_num else ' '
new_num = str(self.new_line_num) if self.new_line_num else ' '
line_num_str = f"{Colors.YELLOW}{old_num:>4} {new_num:>4}{Colors.RESET} "
else:
line_num_str = ''
return f"{line_num_str}{color}{prefix}{self.content}{Colors.RESET}"
class DiffDisplay:
def __init__(self, context_lines: int = 3):
self.context_lines = context_lines
def create_diff(self, old_content: str, new_content: str,
filename: str = "file") -> Tuple[List[DiffLine], DiffStats]:
old_lines = old_content.splitlines(keepends=True)
new_lines = new_content.splitlines(keepends=True)
diff_lines = []
stats = DiffStats()
stats.files_changed = 1
diff = difflib.unified_diff(
old_lines, new_lines,
fromfile=f"a/{filename}",
tofile=f"b/{filename}",
n=self.context_lines
)
old_line_num = 0
new_line_num = 0
for line in diff:
if line.startswith('---') or line.startswith('+++'):
diff_lines.append(DiffLine('header', line.rstrip()))
elif line.startswith('@@'):
diff_lines.append(DiffLine('header', line.rstrip()))
old_line_num, new_line_num = self._parse_hunk_header(line)
elif line.startswith('+'):
stats.insertions += 1
diff_lines.append(DiffLine('add', line[1:].rstrip(), None, new_line_num))
new_line_num += 1
elif line.startswith('-'):
stats.deletions += 1
diff_lines.append(DiffLine('delete', line[1:].rstrip(), old_line_num, None))
old_line_num += 1
elif line.startswith(' '):
diff_lines.append(DiffLine('context', line[1:].rstrip(), old_line_num, new_line_num))
old_line_num += 1
new_line_num += 1
stats.modifications = min(stats.insertions, stats.deletions)
return diff_lines, stats
def _parse_hunk_header(self, header: str) -> Tuple[int, int]:
try:
parts = header.split('@@')[1].strip().split()
old_start = int(parts[0].split(',')[0].replace('-', ''))
new_start = int(parts[1].split(',')[0].replace('+', ''))
return old_start, new_start
except (IndexError, ValueError):
return 0, 0
def render_diff(self, diff_lines: List[DiffLine], stats: DiffStats,
show_line_nums: bool = True, show_stats: bool = True) -> str:
output = []
if show_stats:
output.append(f"\n{Colors.BOLD}{Colors.BLUE}{'=' * 60}{Colors.RESET}")
output.append(f"{Colors.BOLD}{Colors.BLUE}DIFF SUMMARY{Colors.RESET}")
output.append(f"{Colors.BOLD}{Colors.BLUE}{'=' * 60}{Colors.RESET}")
output.append(f"{Colors.BLUE}{stats}{Colors.RESET}\n")
for line in diff_lines:
output.append(line.format(show_line_nums))
if show_stats:
output.append(f"\n{Colors.BOLD}{Colors.BLUE}{'=' * 60}{Colors.RESET}\n")
return '\n'.join(output)
def display_file_diff(self, old_content: str, new_content: str,
filename: str = "file", show_line_nums: bool = True) -> str:
diff_lines, stats = self.create_diff(old_content, new_content, filename)
if not diff_lines:
return f"{Colors.GRAY}No changes detected{Colors.RESET}"
return self.render_diff(diff_lines, stats, show_line_nums)
def display_side_by_side(self, old_content: str, new_content: str,
filename: str = "file", width: int = 80) -> str:
old_lines = old_content.splitlines()
new_lines = new_content.splitlines()
matcher = difflib.SequenceMatcher(None, old_lines, new_lines)
output = []
output.append(f"\n{Colors.BOLD}{Colors.BLUE}{'=' * width}{Colors.RESET}")
output.append(f"{Colors.BOLD}{Colors.BLUE}SIDE-BY-SIDE COMPARISON: {filename}{Colors.RESET}")
output.append(f"{Colors.BOLD}{Colors.BLUE}{'=' * width}{Colors.RESET}\n")
half_width = (width - 5) // 2
for tag, i1, i2, j1, j2 in matcher.get_opcodes():
if tag == 'equal':
for i, (old_line, new_line) in enumerate(zip(old_lines[i1:i2], new_lines[j1:j2])):
old_display = old_line[:half_width].ljust(half_width)
new_display = new_line[:half_width].ljust(half_width)
output.append(f"{Colors.GRAY}{old_display} | {new_display}{Colors.RESET}")
elif tag == 'replace':
max_lines = max(i2 - i1, j2 - j1)
for i in range(max_lines):
old_line = old_lines[i1 + i] if i1 + i < i2 else ""
new_line = new_lines[j1 + i] if j1 + i < j2 else ""
old_display = old_line[:half_width].ljust(half_width)
new_display = new_line[:half_width].ljust(half_width)
output.append(f"{Colors.RED}{old_display}{Colors.RESET} | {Colors.GREEN}{new_display}{Colors.RESET}")
elif tag == 'delete':
for old_line in old_lines[i1:i2]:
old_display = old_line[:half_width].ljust(half_width)
output.append(f"{Colors.RED}{old_display} | {' ' * half_width}{Colors.RESET}")
elif tag == 'insert':
for new_line in new_lines[j1:j2]:
new_display = new_line[:half_width].ljust(half_width)
output.append(f"{' ' * half_width} | {Colors.GREEN}{new_display}{Colors.RESET}")
output.append(f"\n{Colors.BOLD}{Colors.BLUE}{'=' * width}{Colors.RESET}\n")
return '\n'.join(output)
def display_diff(old_content: str, new_content: str, filename: str = "file",
format_type: str = "unified", context_lines: int = 3) -> str:
displayer = DiffDisplay(context_lines)
if format_type == "side-by-side":
return displayer.display_side_by_side(old_content, new_content, filename)
else:
return displayer.display_file_diff(old_content, new_content, filename)
def get_diff_stats(old_content: str, new_content: str) -> Dict[str, int]:
displayer = DiffDisplay()
_, stats = displayer.create_diff(old_content, new_content)
return {
'insertions': stats.insertions,
'deletions': stats.deletions,
'modifications': stats.modifications,
'total_changes': stats.total_changes,
'files_changed': stats.files_changed
}

46
pr/ui/display.py Normal file
View File

@ -0,0 +1,46 @@
import json
from typing import Dict, Any
from pr.ui.colors import Colors
def display_tool_call(tool_name, arguments, status="running", result=None):
status_icons = {
"running": ("", Colors.YELLOW),
"success": ("", Colors.GREEN),
"error": ("", Colors.RED)
}
icon, color = status_icons.get(status, ("", Colors.WHITE))
print(f"\n{Colors.BOLD}{'' * 80}{Colors.RESET}")
print(f"{color}{icon} {Colors.BOLD}{Colors.CYAN}TOOL: {tool_name}{Colors.RESET}")
print(f"{Colors.BOLD}{'' * 80}{Colors.RESET}")
if arguments:
print(f"{Colors.YELLOW}Parameters:{Colors.RESET}")
for key, value in arguments.items():
value_str = str(value)
if len(value_str) > 100:
value_str = value_str[:100] + "..."
print(f" {Colors.CYAN}{key}:{Colors.RESET} {value_str}")
if result is not None and status != "running":
print(f"\n{Colors.YELLOW}Result:{Colors.RESET}")
result_str = json.dumps(result, indent=2) if isinstance(result, dict) else str(result)
if len(result_str) > 500:
result_str = result_str[:500] + f"\n{Colors.GRAY}... (truncated){Colors.RESET}"
if status == "success":
print(f"{Colors.GREEN}{result_str}{Colors.RESET}")
elif status == "error":
print(f"{Colors.RED}{result_str}{Colors.RESET}")
else:
print(result_str)
print(f"{Colors.BOLD}{'' * 80}{Colors.RESET}\n")
def print_autonomous_header(task):
print(f"{Colors.BOLD}Task:{Colors.RESET} {task}")
print(f"{Colors.GRAY}r will work continuously until the task is complete.{Colors.RESET}")
print(f"{Colors.GRAY}Press Ctrl+C twice to interrupt.{Colors.RESET}\n")
print(f"{Colors.BOLD}{'' * 80}{Colors.RESET}\n")

198
pr/ui/edit_feedback.py Normal file
View File

@ -0,0 +1,198 @@
from typing import List, Dict, Optional
from datetime import datetime
from .colors import Colors
from .progress import ProgressBar
class EditOperation:
def __init__(self, op_type: str, filepath: str, start_pos: int = 0,
end_pos: int = 0, content: str = "", old_content: str = ""):
self.op_type = op_type
self.filepath = filepath
self.start_pos = start_pos
self.end_pos = end_pos
self.content = content
self.old_content = old_content
self.timestamp = datetime.now()
self.status = "pending"
def format_operation(self) -> str:
op_colors = {
'INSERT': Colors.GREEN,
'REPLACE': Colors.YELLOW,
'DELETE': Colors.RED,
'WRITE': Colors.BLUE
}
color = op_colors.get(self.op_type, Colors.RESET)
status_icon = {
'pending': '',
'in_progress': '',
'completed': '',
'failed': ''
}.get(self.status, '')
return f"{color}{status_icon} [{self.op_type}]{Colors.RESET} {self.filepath}"
def format_details(self, show_content: bool = True) -> str:
output = [self.format_operation()]
if self.op_type in ('INSERT', 'REPLACE'):
output.append(f" {Colors.GRAY}Position: {self.start_pos}-{self.end_pos}{Colors.RESET}")
if show_content:
if self.old_content:
lines = self.old_content.split('\n')
preview = lines[0][:60] + ('...' if len(lines[0]) > 60 or len(lines) > 1 else '')
output.append(f" {Colors.RED}- {preview}{Colors.RESET}")
if self.content:
lines = self.content.split('\n')
preview = lines[0][:60] + ('...' if len(lines[0]) > 60 or len(lines) > 1 else '')
output.append(f" {Colors.GREEN}+ {preview}{Colors.RESET}")
return '\n'.join(output)
class EditTracker:
def __init__(self):
self.operations: List[EditOperation] = []
self.current_file: Optional[str] = None
def add_operation(self, op_type: str, filepath: str, **kwargs) -> EditOperation:
op = EditOperation(op_type, filepath, **kwargs)
self.operations.append(op)
self.current_file = filepath
return op
def mark_in_progress(self, operation: EditOperation):
operation.status = "in_progress"
def mark_completed(self, operation: EditOperation):
operation.status = "completed"
def mark_failed(self, operation: EditOperation):
operation.status = "failed"
def get_stats(self) -> Dict[str, int]:
stats = {
'total': len(self.operations),
'completed': sum(1 for op in self.operations if op.status == 'completed'),
'pending': sum(1 for op in self.operations if op.status == 'pending'),
'in_progress': sum(1 for op in self.operations if op.status == 'in_progress'),
'failed': sum(1 for op in self.operations if op.status == 'failed')
}
return stats
def get_completion_percentage(self) -> float:
if not self.operations:
return 0.0
stats = self.get_stats()
return (stats['completed'] / stats['total']) * 100
def display_progress(self) -> str:
if not self.operations:
return f"{Colors.GRAY}No edit operations tracked{Colors.RESET}"
output = []
output.append(f"\n{Colors.BOLD}{Colors.BLUE}{'=' * 60}{Colors.RESET}")
output.append(f"{Colors.BOLD}{Colors.BLUE}EDIT OPERATIONS PROGRESS{Colors.RESET}")
output.append(f"{Colors.BOLD}{Colors.BLUE}{'=' * 60}{Colors.RESET}\n")
stats = self.get_stats()
completion = self.get_completion_percentage()
progress_bar = ProgressBar(total=stats['total'], width=40)
progress_bar.current = stats['completed']
bar_display = progress_bar._get_bar_display()
output.append(f"Progress: {bar_display}")
output.append(f"{Colors.BLUE}Total: {stats['total']}, Completed: {stats['completed']}, "
f"Pending: {stats['pending']}, Failed: {stats['failed']}{Colors.RESET}\n")
output.append(f"{Colors.BOLD}Recent Operations:{Colors.RESET}")
for i, op in enumerate(self.operations[-5:], 1):
output.append(f"{i}. {op.format_operation()}")
output.append(f"\n{Colors.BOLD}{Colors.BLUE}{'=' * 60}{Colors.RESET}\n")
return '\n'.join(output)
def display_timeline(self, show_content: bool = False) -> str:
if not self.operations:
return f"{Colors.GRAY}No edit operations tracked{Colors.RESET}"
output = []
output.append(f"\n{Colors.BOLD}{Colors.BLUE}{'=' * 60}{Colors.RESET}")
output.append(f"{Colors.BOLD}{Colors.BLUE}EDIT TIMELINE{Colors.RESET}")
output.append(f"{Colors.BOLD}{Colors.BLUE}{'=' * 60}{Colors.RESET}\n")
for i, op in enumerate(self.operations, 1):
timestamp = op.timestamp.strftime("%H:%M:%S")
output.append(f"{Colors.GRAY}[{timestamp}]{Colors.RESET} {i}.")
output.append(op.format_details(show_content))
output.append("")
stats = self.get_stats()
output.append(f"{Colors.BOLD}Summary:{Colors.RESET}")
output.append(f"{Colors.BLUE}Total operations: {stats['total']}, "
f"Completed: {stats['completed']}, Failed: {stats['failed']}{Colors.RESET}")
output.append(f"\n{Colors.BOLD}{Colors.BLUE}{'=' * 60}{Colors.RESET}\n")
return '\n'.join(output)
def display_summary(self) -> str:
if not self.operations:
return f"{Colors.GRAY}No edits to summarize{Colors.RESET}"
stats = self.get_stats()
files_modified = len(set(op.filepath for op in self.operations))
output = []
output.append(f"\n{Colors.BOLD}{Colors.GREEN}{'=' * 60}{Colors.RESET}")
output.append(f"{Colors.BOLD}{Colors.GREEN}EDIT SUMMARY{Colors.RESET}")
output.append(f"{Colors.BOLD}{Colors.GREEN}{'=' * 60}{Colors.RESET}\n")
output.append(f"{Colors.GREEN}Files Modified: {files_modified}{Colors.RESET}")
output.append(f"{Colors.GREEN}Total Operations: {stats['total']}{Colors.RESET}")
output.append(f"{Colors.GREEN}Successful: {stats['completed']}{Colors.RESET}")
if stats['failed'] > 0:
output.append(f"{Colors.RED}Failed: {stats['failed']}{Colors.RESET}")
output.append(f"\n{Colors.BOLD}Operations by Type:{Colors.RESET}")
op_types = {}
for op in self.operations:
op_types[op.op_type] = op_types.get(op.op_type, 0) + 1
for op_type, count in sorted(op_types.items()):
output.append(f" {op_type}: {count}")
output.append(f"\n{Colors.BOLD}{Colors.GREEN}{'=' * 60}{Colors.RESET}\n")
return '\n'.join(output)
def clear(self):
self.operations.clear()
self.current_file = None
tracker = EditTracker()
def track_edit(op_type: str, filepath: str, **kwargs) -> EditOperation:
return tracker.add_operation(op_type, filepath, **kwargs)
def display_edit_progress() -> str:
return tracker.display_progress()
def display_edit_timeline(show_content: bool = False) -> str:
return tracker.display_timeline(show_content)
def display_edit_summary() -> str:
return tracker.display_summary()
def clear_tracker():
tracker.clear()

69
pr/ui/output.py Normal file
View File

@ -0,0 +1,69 @@
import json
import sys
from typing import Any, Dict, List
from datetime import datetime
class OutputFormatter:
def __init__(self, format_type: str = 'text', quiet: bool = False):
self.format_type = format_type
self.quiet = quiet
def output(self, data: Any, message_type: str = 'response'):
if self.quiet and message_type not in ['error', 'result']:
return
if self.format_type == 'json':
self._output_json(data, message_type)
elif self.format_type == 'structured':
self._output_structured(data, message_type)
else:
self._output_text(data, message_type)
def _output_json(self, data: Any, message_type: str):
output = {
'type': message_type,
'timestamp': datetime.now().isoformat(),
'data': data
}
print(json.dumps(output, indent=2))
def _output_structured(self, data: Any, message_type: str):
if isinstance(data, dict):
for key, value in data.items():
print(f"{key}: {value}")
elif isinstance(data, list):
for item in data:
print(f"- {item}")
else:
print(data)
def _output_text(self, data: Any, message_type: str):
if isinstance(data, (dict, list)):
print(json.dumps(data, indent=2))
else:
print(data)
def error(self, message: str):
if self.format_type == 'json':
self._output_json({'error': message}, 'error')
else:
print(f"Error: {message}", file=sys.stderr)
def success(self, message: str):
if not self.quiet:
if self.format_type == 'json':
self._output_json({'success': message}, 'success')
else:
print(message)
def info(self, message: str):
if not self.quiet:
if self.format_type == 'json':
self._output_json({'info': message}, 'info')
else:
print(message)
def result(self, data: Any):
self.output(data, 'result')

76
pr/ui/progress.py Normal file
View File

@ -0,0 +1,76 @@
import sys
import time
import threading
class ProgressIndicator:
def __init__(self, message: str = "Working", show: bool = True):
self.message = message
self.show = show
self.running = False
self.thread = None
def __enter__(self):
if self.show:
self.start()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if self.show:
self.stop()
def start(self):
self.running = True
self.thread = threading.Thread(target=self._animate, daemon=True)
self.thread.start()
def stop(self):
if self.running:
self.running = False
if self.thread:
self.thread.join(timeout=1.0)
sys.stdout.write('\r' + ' ' * (len(self.message) + 10) + '\r')
sys.stdout.flush()
def _animate(self):
spinner = ['', '', '', '', '', '', '', '', '', '']
idx = 0
while self.running:
sys.stdout.write(f'\r{spinner[idx]} {self.message}...')
sys.stdout.flush()
idx = (idx + 1) % len(spinner)
time.sleep(0.1)
class ProgressBar:
def __init__(self, total: int, description: str = "Progress", width: int = 40):
self.total = total
self.description = description
self.width = width
self.current = 0
def update(self, amount: int = 1):
self.current += amount
self._display()
def _display(self):
if self.total == 0:
percent = 100
else:
percent = int((self.current / self.total) * 100)
filled = int((self.current / self.total) * self.width) if self.total > 0 else self.width
bar = '' * filled + '' * (self.width - filled)
sys.stdout.write(f'\r{self.description}: |{bar}| {percent}% ({self.current}/{self.total})')
sys.stdout.flush()
if self.current >= self.total:
sys.stdout.write('\n')
def finish(self):
self.current = self.total
self._display()

90
pr/ui/rendering.py Normal file
View File

@ -0,0 +1,90 @@
import re
from pr.ui.colors import Colors
from pr.config import LANGUAGE_KEYWORDS
def highlight_code(code, language=None, syntax_highlighting=True):
if not syntax_highlighting:
return code
if not language:
if 'def ' in code or 'import ' in code:
language = 'python'
elif 'function ' in code or 'const ' in code:
language = 'javascript'
elif 'public ' in code or 'class ' in code:
language = 'java'
if language and language in LANGUAGE_KEYWORDS:
keywords = LANGUAGE_KEYWORDS[language]
for keyword in keywords:
pattern = r'\b' + re.escape(keyword) + r'\b'
code = re.sub(pattern, f"{Colors.BLUE}{keyword}{Colors.RESET}", code)
code = re.sub(r'"([^"]*)"', f'{Colors.GREEN}"\\1"{Colors.RESET}', code)
code = re.sub(r"'([^']*)'", f"{Colors.GREEN}'\\1'{Colors.RESET}", code)
code = re.sub(r'#(.*)$', f'{Colors.GRAY}#\\1{Colors.RESET}', code, flags=re.MULTILINE)
code = re.sub(r'//(.*)$', f'{Colors.GRAY}//\\1{Colors.RESET}', code, flags=re.MULTILINE)
return code
def render_markdown(text, syntax_highlighting=True):
if not syntax_highlighting:
return text
code_blocks = []
def extract_code_block(match):
lang = match.group(1) or ''
code = match.group(2)
highlighted_code = highlight_code(code.strip('\n'), lang, syntax_highlighting)
placeholder = f"%%CODEBLOCK{len(code_blocks)}%%"
full_block = f'{Colors.GRAY}```{lang}{Colors.RESET}\n{highlighted_code}\n{Colors.GRAY}```{Colors.RESET}'
code_blocks.append(full_block)
return placeholder
text = re.sub(r'```(\w*)\n(.*?)\n?```', extract_code_block, text, flags=re.DOTALL)
inline_codes = []
def extract_inline_code(match):
code = match.group(1)
placeholder = f"%%INLINECODE{len(inline_codes)}%%"
inline_codes.append(f'{Colors.YELLOW}{code}{Colors.RESET}')
return placeholder
text = re.sub(r'`([^`]+)`', extract_inline_code, text)
lines = text.split('\n')
processed_lines = []
for line in lines:
if line.startswith('### '):
line = f'{Colors.BOLD}{Colors.GREEN}{line[4:]}{Colors.RESET}'
elif line.startswith('## '):
line = f'{Colors.BOLD}{Colors.BLUE}{line[3:]}{Colors.RESET}'
elif line.startswith('# '):
line = f'{Colors.BOLD}{Colors.MAGENTA}{line[2:]}{Colors.RESET}'
elif line.startswith('> '):
line = f'{Colors.CYAN}> {line[2:]}{Colors.RESET}'
elif re.match(r'^\s*[\*\-\+]\s', line):
match = re.match(r'^(\s*)([\*\-\+])(\s+.*)', line)
if match:
line = f"{match.group(1)}{Colors.YELLOW}{match.group(2)}{Colors.RESET}{match.group(3)}"
elif re.match(r'^\s*\d+\.\s', line):
match = re.match(r'^(\s*)(\d+\.)(\s+.*)', line)
if match:
line = f"{match.group(1)}{Colors.YELLOW}{match.group(2)}{Colors.RESET}{match.group(3)}"
processed_lines.append(line)
text = '\n'.join(processed_lines)
text = re.sub(r'\[(.*?)\]\((.*?)\)', f'{Colors.BLUE}\\1{Colors.RESET}{Colors.GRAY}(\\2){Colors.RESET}', text)
text = re.sub(r'~~(.*?)~~', f'{Colors.GRAY}\\1{Colors.RESET}', text)
text = re.sub(r'\*\*(.*?)\*\*', f'{Colors.BOLD}\\1{Colors.RESET}', text)
text = re.sub(r'__(.*?)__', f'{Colors.BOLD}\\1{Colors.RESET}', text)
text = re.sub(r'\*(.*?)\*', f'{Colors.CYAN}\\1{Colors.RESET}', text)
text = re.sub(r'_(.*?)_', f'{Colors.CYAN}\\1{Colors.RESET}', text)
for i, code in enumerate(inline_codes):
text = text.replace(f'%%INLINECODE{i}%%', code)
for i, block in enumerate(code_blocks):
text = text.replace(f'%%CODEBLOCK{i}%%', block)
return text

5
pr/workflows/__init__.py Normal file
View File

@ -0,0 +1,5 @@
from .workflow_definition import Workflow, WorkflowStep, ExecutionMode
from .workflow_engine import WorkflowEngine
from .workflow_storage import WorkflowStorage
__all__ = ['Workflow', 'WorkflowStep', 'ExecutionMode', 'WorkflowEngine', 'WorkflowStorage']

View File

@ -0,0 +1,91 @@
from enum import Enum
from typing import List, Dict, Any, Optional
from dataclasses import dataclass, field
class ExecutionMode(Enum):
SEQUENTIAL = "sequential"
PARALLEL = "parallel"
CONDITIONAL = "conditional"
@dataclass
class WorkflowStep:
tool_name: str
arguments: Dict[str, Any]
step_id: str
condition: Optional[str] = None
on_success: Optional[List[str]] = None
on_failure: Optional[List[str]] = None
retry_count: int = 0
timeout_seconds: int = 300
def to_dict(self) -> Dict[str, Any]:
return {
'tool_name': self.tool_name,
'arguments': self.arguments,
'step_id': self.step_id,
'condition': self.condition,
'on_success': self.on_success,
'on_failure': self.on_failure,
'retry_count': self.retry_count,
'timeout_seconds': self.timeout_seconds
}
@staticmethod
def from_dict(data: Dict[str, Any]) -> 'WorkflowStep':
return WorkflowStep(
tool_name=data['tool_name'],
arguments=data['arguments'],
step_id=data['step_id'],
condition=data.get('condition'),
on_success=data.get('on_success'),
on_failure=data.get('on_failure'),
retry_count=data.get('retry_count', 0),
timeout_seconds=data.get('timeout_seconds', 300)
)
@dataclass
class Workflow:
name: str
description: str
steps: List[WorkflowStep]
execution_mode: ExecutionMode = ExecutionMode.SEQUENTIAL
variables: Dict[str, Any] = field(default_factory=dict)
tags: List[str] = field(default_factory=list)
def to_dict(self) -> Dict[str, Any]:
return {
'name': self.name,
'description': self.description,
'steps': [step.to_dict() for step in self.steps],
'execution_mode': self.execution_mode.value,
'variables': self.variables,
'tags': self.tags
}
@staticmethod
def from_dict(data: Dict[str, Any]) -> 'Workflow':
return Workflow(
name=data['name'],
description=data['description'],
steps=[WorkflowStep.from_dict(step) for step in data['steps']],
execution_mode=ExecutionMode(data.get('execution_mode', 'sequential')),
variables=data.get('variables', {}),
tags=data.get('tags', [])
)
def add_step(self, step: WorkflowStep):
self.steps.append(step)
def get_step(self, step_id: str) -> Optional[WorkflowStep]:
for step in self.steps:
if step.step_id == step_id:
return step
return None
def get_initial_steps(self) -> List[WorkflowStep]:
if self.execution_mode == ExecutionMode.SEQUENTIAL:
return [self.steps[0]] if self.steps else []
elif self.execution_mode == ExecutionMode.PARALLEL:
return self.steps
else:
return [step for step in self.steps if not step.condition]

View File

@ -0,0 +1,192 @@
import time
import re
from typing import Dict, Any, List, Callable, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed
from .workflow_definition import Workflow, WorkflowStep, ExecutionMode
class WorkflowExecutionContext:
def __init__(self):
self.variables: Dict[str, Any] = {}
self.step_results: Dict[str, Any] = {}
self.execution_log: List[Dict[str, Any]] = []
def set_variable(self, name: str, value: Any):
self.variables[name] = value
def get_variable(self, name: str, default: Any = None) -> Any:
return self.variables.get(name, default)
def set_step_result(self, step_id: str, result: Any):
self.step_results[step_id] = result
def get_step_result(self, step_id: str) -> Any:
return self.step_results.get(step_id)
def log_event(self, event_type: str, step_id: str, details: Dict[str, Any]):
self.execution_log.append({
'timestamp': time.time(),
'event_type': event_type,
'step_id': step_id,
'details': details
})
class WorkflowEngine:
def __init__(self, tool_executor: Callable, max_workers: int = 5):
self.tool_executor = tool_executor
self.max_workers = max_workers
def _evaluate_condition(self, condition: str, context: WorkflowExecutionContext) -> bool:
if not condition:
return True
try:
safe_locals = {
'variables': context.variables,
'results': context.step_results
}
return eval(condition, {"__builtins__": {}}, safe_locals)
except Exception:
return False
def _substitute_variables(self, arguments: Dict[str, Any], context: WorkflowExecutionContext) -> Dict[str, Any]:
substituted = {}
for key, value in arguments.items():
if isinstance(value, str):
pattern = r'\$\{([^}]+)\}'
matches = re.findall(pattern, value)
for match in matches:
if match.startswith('step.'):
step_id = match.split('.', 1)[1]
replacement = context.get_step_result(step_id)
if replacement is not None:
value = value.replace(f'${{{match}}}', str(replacement))
elif match.startswith('var.'):
var_name = match.split('.', 1)[1]
replacement = context.get_variable(var_name)
if replacement is not None:
value = value.replace(f'${{{match}}}', str(replacement))
substituted[key] = value
else:
substituted[key] = value
return substituted
def _execute_step(self, step: WorkflowStep, context: WorkflowExecutionContext) -> Dict[str, Any]:
if not self._evaluate_condition(step.condition, context):
context.log_event('skipped', step.step_id, {'reason': 'condition_not_met'})
return {'status': 'skipped', 'step_id': step.step_id}
arguments = self._substitute_variables(step.arguments, context)
start_time = time.time()
retry_attempts = 0
last_error = None
while retry_attempts <= step.retry_count:
try:
context.log_event('executing', step.step_id, {
'tool': step.tool_name,
'arguments': arguments,
'attempt': retry_attempts + 1
})
result = self.tool_executor(step.tool_name, arguments)
execution_time = time.time() - start_time
context.set_step_result(step.step_id, result)
context.log_event('completed', step.step_id, {
'execution_time': execution_time,
'result_size': len(str(result)) if result else 0
})
return {
'status': 'success',
'step_id': step.step_id,
'result': result,
'execution_time': execution_time
}
except Exception as e:
last_error = str(e)
retry_attempts += 1
if retry_attempts <= step.retry_count:
time.sleep(1 * retry_attempts)
context.log_event('failed', step.step_id, {'error': last_error})
return {
'status': 'failed',
'step_id': step.step_id,
'error': last_error,
'execution_time': time.time() - start_time
}
def _get_next_steps(self, completed_step: WorkflowStep, result: Dict[str, Any],
workflow: Workflow) -> List[WorkflowStep]:
next_steps = []
if result['status'] == 'success' and completed_step.on_success:
for step_id in completed_step.on_success:
step = workflow.get_step(step_id)
if step:
next_steps.append(step)
elif result['status'] == 'failed' and completed_step.on_failure:
for step_id in completed_step.on_failure:
step = workflow.get_step(step_id)
if step:
next_steps.append(step)
elif workflow.execution_mode == ExecutionMode.SEQUENTIAL:
current_index = workflow.steps.index(completed_step)
if current_index + 1 < len(workflow.steps):
next_steps.append(workflow.steps[current_index + 1])
return next_steps
def execute_workflow(self, workflow: Workflow, initial_variables: Optional[Dict[str, Any]] = None) -> WorkflowExecutionContext:
context = WorkflowExecutionContext()
if initial_variables:
context.variables.update(initial_variables)
if workflow.variables:
context.variables.update(workflow.variables)
context.log_event('workflow_started', 'workflow', {'name': workflow.name})
if workflow.execution_mode == ExecutionMode.PARALLEL:
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
futures = {
executor.submit(self._execute_step, step, context): step
for step in workflow.steps
}
for future in as_completed(futures):
step = futures[future]
try:
result = future.result()
context.log_event('step_completed', step.step_id, result)
except Exception as e:
context.log_event('step_failed', step.step_id, {'error': str(e)})
else:
pending_steps = workflow.get_initial_steps()
executed_step_ids = set()
while pending_steps:
step = pending_steps.pop(0)
if step.step_id in executed_step_ids:
continue
result = self._execute_step(step, context)
executed_step_ids.add(step.step_id)
next_steps = self._get_next_steps(step, result, workflow)
pending_steps.extend(next_steps)
context.log_event('workflow_completed', 'workflow', {
'total_steps': len(context.step_results),
'executed_steps': list(context.step_results.keys())
})
return context

View File

@ -0,0 +1,214 @@
import json
import sqlite3
import time
from typing import List, Optional
from .workflow_definition import Workflow
class WorkflowStorage:
def __init__(self, db_path: str):
self.db_path = db_path
self._initialize_storage()
def _initialize_storage(self):
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute('''
CREATE TABLE IF NOT EXISTS workflows (
workflow_id TEXT PRIMARY KEY,
name TEXT NOT NULL,
description TEXT,
workflow_data TEXT NOT NULL,
created_at INTEGER NOT NULL,
updated_at INTEGER NOT NULL,
execution_count INTEGER DEFAULT 0,
last_execution_at INTEGER,
tags TEXT
)
''')
cursor.execute('''
CREATE TABLE IF NOT EXISTS workflow_executions (
execution_id TEXT PRIMARY KEY,
workflow_id TEXT NOT NULL,
started_at INTEGER NOT NULL,
completed_at INTEGER,
status TEXT NOT NULL,
execution_log TEXT,
variables TEXT,
step_results TEXT,
FOREIGN KEY (workflow_id) REFERENCES workflows(workflow_id)
)
''')
cursor.execute('''
CREATE INDEX IF NOT EXISTS idx_workflow_name ON workflows(name)
''')
cursor.execute('''
CREATE INDEX IF NOT EXISTS idx_execution_workflow ON workflow_executions(workflow_id)
''')
cursor.execute('''
CREATE INDEX IF NOT EXISTS idx_execution_started ON workflow_executions(started_at)
''')
conn.commit()
conn.close()
def save_workflow(self, workflow: Workflow) -> str:
import hashlib
workflow_data = json.dumps(workflow.to_dict())
workflow_id = hashlib.sha256(workflow.name.encode()).hexdigest()[:16]
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
current_time = int(time.time())
tags_json = json.dumps(workflow.tags)
cursor.execute('''
INSERT OR REPLACE INTO workflows
(workflow_id, name, description, workflow_data, created_at, updated_at, tags)
VALUES (?, ?, ?, ?, ?, ?, ?)
''', (workflow_id, workflow.name, workflow.description, workflow_data,
current_time, current_time, tags_json))
conn.commit()
conn.close()
return workflow_id
def load_workflow(self, workflow_id: str) -> Optional[Workflow]:
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute('SELECT workflow_data FROM workflows WHERE workflow_id = ?', (workflow_id,))
row = cursor.fetchone()
conn.close()
if row:
workflow_dict = json.loads(row[0])
return Workflow.from_dict(workflow_dict)
return None
def load_workflow_by_name(self, name: str) -> Optional[Workflow]:
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute('SELECT workflow_data FROM workflows WHERE name = ?', (name,))
row = cursor.fetchone()
conn.close()
if row:
workflow_dict = json.loads(row[0])
return Workflow.from_dict(workflow_dict)
return None
def list_workflows(self, tag: Optional[str] = None) -> List[dict]:
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
if tag:
cursor.execute('''
SELECT workflow_id, name, description, execution_count, last_execution_at, tags
FROM workflows
WHERE tags LIKE ?
ORDER BY name
''', (f'%"{tag}"%',))
else:
cursor.execute('''
SELECT workflow_id, name, description, execution_count, last_execution_at, tags
FROM workflows
ORDER BY name
''')
workflows = []
for row in cursor.fetchall():
workflows.append({
'workflow_id': row[0],
'name': row[1],
'description': row[2],
'execution_count': row[3],
'last_execution_at': row[4],
'tags': json.loads(row[5]) if row[5] else []
})
conn.close()
return workflows
def delete_workflow(self, workflow_id: str) -> bool:
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute('DELETE FROM workflows WHERE workflow_id = ?', (workflow_id,))
deleted = cursor.rowcount > 0
cursor.execute('DELETE FROM workflow_executions WHERE workflow_id = ?', (workflow_id,))
conn.commit()
conn.close()
return deleted
def save_execution(self, workflow_id: str, execution_context: 'WorkflowExecutionContext') -> str:
import hashlib
import uuid
execution_id = str(uuid.uuid4())[:16]
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
started_at = int(execution_context.execution_log[0]['timestamp']) if execution_context.execution_log else int(time.time())
completed_at = int(time.time())
cursor.execute('''
INSERT INTO workflow_executions
(execution_id, workflow_id, started_at, completed_at, status, execution_log, variables, step_results)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
''', (
execution_id,
workflow_id,
started_at,
completed_at,
'completed',
json.dumps(execution_context.execution_log),
json.dumps(execution_context.variables),
json.dumps(execution_context.step_results)
))
cursor.execute('''
UPDATE workflows
SET execution_count = execution_count + 1,
last_execution_at = ?
WHERE workflow_id = ?
''', (completed_at, workflow_id))
conn.commit()
conn.close()
return execution_id
def get_execution_history(self, workflow_id: str, limit: int = 10) -> List[dict]:
conn = sqlite3.connect(self.db_path, check_same_thread=False)
cursor = conn.cursor()
cursor.execute('''
SELECT execution_id, started_at, completed_at, status
FROM workflow_executions
WHERE workflow_id = ?
ORDER BY started_at DESC
LIMIT ?
''', (workflow_id, limit))
executions = []
for row in cursor.fetchall():
executions.append({
'execution_id': row[0],
'started_at': row[1],
'completed_at': row[2],
'status': row[3]
})
conn.close()
return executions

115
pyproject.toml Normal file
View File

@ -0,0 +1,115 @@
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
[project]
name = "pr-assistant"
version = "1.0.0"
description = "Professional CLI AI assistant with autonomous execution capabilities"
readme = "README.md"
requires-python = ">=3.8"
license = {text = "MIT"}
keywords = ["ai", "assistant", "cli", "automation", "openrouter", "autonomous"]
authors = [
{name = "retoor", email = "retoor@example.com"}
]
classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Topic :: Software Development :: Libraries :: Python Modules",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
]
[project.optional-dependencies]
dev = [
"pytest>=7.0.0",
"pytest-cov>=4.0.0",
"black>=23.0.0",
"flake8>=6.0.0",
"mypy>=1.0.0",
"pre-commit>=3.0.0",
]
[project.scripts]
pr = "pr.__main__:main"
rp = "pr.__main__:main"
rpe = "pr.editor:main"
[project.urls]
Homepage = "https://github.com/retoor/pr-assistant"
Documentation = "https://github.com/retoor/pr-assistant#readme"
Repository = "https://github.com/retoor/pr-assistant"
"Bug Tracker" = "https://github.com/retoor/pr-assistant/issues"
[tool.setuptools.packages.find]
where = ["."]
include = ["pr*"]
exclude = ["tests*"]
[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = ["test_*.py"]
python_classes = ["Test*"]
python_functions = ["test_*"]
addopts = "-v --cov=pr --cov-report=term-missing --cov-report=html"
[tool.black]
line-length = 100
target-version = ['py38', 'py39', 'py310', 'py311']
include = '\.pyi?$'
extend-exclude = '''
/(
__pycache__
| \.git
| \.mypy_cache
| \.pytest_cache
| \.venv
| build
| dist
)/
'''
[tool.mypy]
python_version = "3.8"
warn_return_any = true
warn_unused_configs = true
disallow_untyped_defs = false
disallow_incomplete_defs = false
check_untyped_defs = true
no_implicit_optional = true
warn_redundant_casts = true
warn_unused_ignores = true
[tool.coverage.run]
source = ["pr"]
omit = ["*/tests/*", "*/__pycache__/*"]
[tool.coverage.report]
exclude_lines = [
"pragma: no cover",
"def __repr__",
"raise AssertionError",
"raise NotImplementedError",
"if __name__ == .__main__.:",
"if TYPE_CHECKING:",
]
[tool.isort]
profile = "black"
line_length = 100
multi_line_output = 3
include_trailing_comma = true
force_grid_wrap = 0
use_parentheses = true
ensure_newline_before_comments = true
[tool.bandit]
exclude_dirs = ["tests", "venv", ".venv"]
skips = ["B101"]

7
rp.py Executable file
View File

@ -0,0 +1,7 @@
#!/usr/bin/env python3
import sys
from pr.__main__ import main
if __name__ == '__main__':
main()

0
tests/__init__.py Normal file
View File

53
tests/conftest.py Normal file
View File

@ -0,0 +1,53 @@
import pytest
import os
import tempfile
from unittest.mock import MagicMock
@pytest.fixture
def temp_dir():
with tempfile.TemporaryDirectory() as tmpdir:
yield tmpdir
@pytest.fixture
def mock_api_response():
return {
'choices': [
{
'message': {
'role': 'assistant',
'content': 'Test response'
}
}
],
'usage': {
'prompt_tokens': 10,
'completion_tokens': 5,
'total_tokens': 15
}
}
@pytest.fixture
def mock_args():
args = MagicMock()
args.message = None
args.model = None
args.api_url = None
args.model_list_url = None
args.interactive = False
args.verbose = False
args.no_syntax = False
args.include_env = False
args.context = None
args.api_mode = False
return args
@pytest.fixture
def sample_context_file(temp_dir):
context_path = os.path.join(temp_dir, '.rcontext.txt')
with open(context_path, 'w') as f:
f.write('Sample context content\n')
return context_path

127
tests/test_agents.py Normal file
View File

@ -0,0 +1,127 @@
import pytest
import time
from pr.agents.agent_roles import AgentRole, get_agent_role, list_agent_roles
from pr.agents.agent_manager import AgentManager, AgentInstance
from pr.agents.agent_communication import AgentCommunicationBus, AgentMessage, MessageType
def test_get_agent_role():
role = get_agent_role('coding')
assert isinstance(role, AgentRole)
assert role.name == 'coding'
def test_list_agent_roles():
roles = list_agent_roles()
assert isinstance(roles, dict)
assert len(roles) > 0
assert 'coding' in roles
def test_agent_role():
role = AgentRole(name='test', description='test', system_prompt='test', allowed_tools=set(), specialization_areas=[])
assert role.name == 'test'
def test_agent_instance():
role = get_agent_role('coding')
instance = AgentInstance(agent_id='test', role=role)
assert instance.agent_id == 'test'
assert instance.role == role
def test_agent_manager_init():
mgr = AgentManager(':memory:', None)
assert mgr is not None
def test_agent_manager_create_agent():
mgr = AgentManager(':memory:', None)
agent = mgr.create_agent('coding', 'test_agent')
assert agent is not None
def test_agent_manager_get_agent():
mgr = AgentManager(':memory:', None)
mgr.create_agent('coding', 'test_agent')
agent = mgr.get_agent('test_agent')
assert isinstance(agent, AgentInstance)
def test_agent_manager_remove_agent():
mgr = AgentManager(':memory:', None)
mgr.create_agent('coding', 'test_agent')
mgr.remove_agent('test_agent')
agent = mgr.get_agent('test_agent')
assert agent is None
def test_agent_manager_send_agent_message():
mgr = AgentManager(':memory:', None)
mgr.create_agent('coding', 'a')
mgr.create_agent('coding', 'b')
mgr.send_agent_message('a', 'b', 'test')
assert True
def test_agent_manager_get_agent_messages():
mgr = AgentManager(':memory:', None)
mgr.create_agent('coding', 'test')
messages = mgr.get_agent_messages('test')
assert isinstance(messages, list)
def test_agent_manager_get_session_summary():
mgr = AgentManager(':memory:', None)
summary = mgr.get_session_summary()
assert isinstance(summary, str)
def test_agent_manager_collaborate_agents():
mgr = AgentManager(':memory:', None)
result = mgr.collaborate_agents('orchestrator', 'task', ['coding', 'research'])
assert result is not None
def test_agent_manager_execute_agent_task():
mgr = AgentManager(':memory:', None)
mgr.create_agent('coding', 'test')
result = mgr.execute_agent_task('test', 'task')
assert result is not None
def test_agent_manager_clear_session():
mgr = AgentManager(':memory:', None)
mgr.clear_session()
assert True
def test_agent_message():
msg = AgentMessage(from_agent='a', to_agent='b', message_type=MessageType.REQUEST, content='test', metadata={}, timestamp=1.0, message_id='id')
assert msg.from_agent == 'a'
def test_agent_message_to_dict():
msg = AgentMessage(from_agent='a', to_agent='b', message_type=MessageType.REQUEST, content='test', metadata={}, timestamp=1.0, message_id='id')
d = msg.to_dict()
assert isinstance(d, dict)
def test_agent_message_from_dict():
d = {'from_agent': 'a', 'to_agent': 'b', 'message_type': 'request', 'content': 'test', 'metadata': {}, 'timestamp': 1.0, 'message_id': 'id'}
msg = AgentMessage.from_dict(d)
assert isinstance(msg, AgentMessage)
def test_agent_communication_bus_init():
bus = AgentCommunicationBus(':memory:')
assert bus is not None
def test_agent_communication_bus_send_message():
bus = AgentCommunicationBus(':memory:')
msg = AgentMessage(from_agent='a', to_agent='b', message_type=MessageType.REQUEST, content='test', metadata={}, timestamp=1.0, message_id='id')
bus.send_message(msg)
assert True
def test_agent_communication_bus_receive_messages():
bus = AgentCommunicationBus(':memory:')
msg = AgentMessage(from_agent='a', to_agent='b', message_type=MessageType.REQUEST, content='test', metadata={}, timestamp=1.0, message_id='id')
bus.send_message(msg)
messages = bus.receive_messages('b')
assert len(messages) == 1
def test_agent_communication_bus_get_conversation_history():
bus = AgentCommunicationBus(':memory:')
msg = AgentMessage(from_agent='a', to_agent='b', message_type=MessageType.REQUEST, content='test', metadata={}, timestamp=1.0, message_id='id')
bus.send_message(msg)
history = bus.get_conversation_history('a', 'b')
assert len(history) == 1
def test_agent_communication_bus_mark_as_read():
bus = AgentCommunicationBus(':memory:')
msg = AgentMessage(from_agent='a', to_agent='b', message_type=MessageType.REQUEST, content='test', metadata={}, timestamp=1.0, message_id='id')
bus.send_message(msg)
bus.mark_as_read(msg.message_id)
assert True

31
tests/test_config.py Normal file
View File

@ -0,0 +1,31 @@
import pytest
from pr import config
class TestConfig:
def test_default_model_exists(self):
assert hasattr(config, 'DEFAULT_MODEL')
assert isinstance(config.DEFAULT_MODEL, str)
assert len(config.DEFAULT_MODEL) > 0
def test_api_url_exists(self):
assert hasattr(config, 'DEFAULT_API_URL')
assert config.DEFAULT_API_URL.startswith('http')
def test_file_paths_exist(self):
assert hasattr(config, 'DB_PATH')
assert hasattr(config, 'LOG_FILE')
assert hasattr(config, 'HISTORY_FILE')
def test_autonomous_config(self):
assert hasattr(config, 'MAX_AUTONOMOUS_ITERATIONS')
assert config.MAX_AUTONOMOUS_ITERATIONS > 0
assert hasattr(config, 'CONTEXT_COMPRESSION_THRESHOLD')
assert config.CONTEXT_COMPRESSION_THRESHOLD > 0
def test_language_keywords(self):
assert hasattr(config, 'LANGUAGE_KEYWORDS')
assert 'python' in config.LANGUAGE_KEYWORDS
assert isinstance(config.LANGUAGE_KEYWORDS['python'], list)

35
tests/test_context.py Normal file
View File

@ -0,0 +1,35 @@
import pytest
from pr.core.context import should_compress_context, compress_context
from pr.config import RECENT_MESSAGES_TO_KEEP
class TestContextManagement:
def test_should_compress_context_below_threshold(self):
messages = [{'role': 'user', 'content': 'test'}] * 10
assert should_compress_context(messages) is False
def test_should_compress_context_above_threshold(self):
messages = [{'role': 'user', 'content': 'test'}] * 35
assert should_compress_context(messages) is True
def test_compress_context_preserves_system_message(self):
messages = [
{'role': 'system', 'content': 'System prompt'},
{'role': 'user', 'content': 'Hello'},
{'role': 'assistant', 'content': 'Hi'},
] * 40 # Ensure compression
compressed = compress_context(messages)
assert compressed[0]['role'] == 'system'
assert 'System prompt' in compressed[0]['content']
def test_compress_context_keeps_recent_messages(self):
messages = [{'role': 'user', 'content': f'msg{i}'} for i in range(40)]
compressed = compress_context(messages)
# Should keep recent messages
recent = compressed[-RECENT_MESSAGES_TO_KEEP:]
assert len(recent) == RECENT_MESSAGES_TO_KEEP
# Check that the messages are the most recent ones
for i, msg in enumerate(recent):
expected_index = 40 - RECENT_MESSAGES_TO_KEEP + i
assert msg['content'] == f'msg{expected_index}'

118
tests/test_tools.py Normal file
View File

@ -0,0 +1,118 @@
import pytest
import os
import tempfile
from pr.tools.filesystem import read_file, write_file, list_directory, search_replace
from pr.tools.patch import apply_patch, create_diff
from pr.tools.base import get_tools_definition
class TestFilesystemTools:
def test_write_and_read_file(self, temp_dir):
filepath = os.path.join(temp_dir, 'test.txt')
content = 'Hello, World!'
write_result = write_file(filepath, content)
assert write_result['status'] == 'success'
read_result = read_file(filepath)
assert read_result['status'] == 'success'
assert content in read_result['content']
def test_read_nonexistent_file(self):
result = read_file('/nonexistent/path/file.txt')
assert result['status'] == 'error'
def test_list_directory(self, temp_dir):
test_file = os.path.join(temp_dir, 'testfile.txt')
with open(test_file, 'w') as f:
f.write('test')
result = list_directory(temp_dir)
assert result['status'] == 'success'
assert any(item['name'] == 'testfile.txt' for item in result['items'])
def test_search_replace(self, temp_dir):
filepath = os.path.join(temp_dir, 'test.txt')
content = 'Hello, World!'
with open(filepath, 'w') as f:
f.write(content)
result = search_replace(filepath, 'World', 'Universe')
assert result['status'] == 'success'
read_result = read_file(filepath)
assert 'Hello, Universe!' in read_result['content']
class TestPatchTools:
def test_create_diff(self, temp_dir):
file1 = os.path.join(temp_dir, 'file1.txt')
file2 = os.path.join(temp_dir, 'file2.txt')
with open(file1, 'w') as f:
f.write('line1\nline2\nline3\n')
with open(file2, 'w') as f:
f.write('line1\nline2 modified\nline3\n')
result = create_diff(file1, file2)
assert result['status'] == 'success'
assert 'line2' in result['diff']
assert 'line2 modified' in result['diff']
def test_apply_patch(self, temp_dir):
filepath = os.path.join(temp_dir, 'file.txt')
with open(filepath, 'w') as f:
f.write('line1\nline2\nline3\n')
# Create a simple patch
patch_content = """--- a/file.txt
+++ b/file.txt
@@ -1,3 +1,3 @@
line1
-line2
+line2 modified
line3
"""
result = apply_patch(filepath, patch_content)
assert result['status'] == 'success'
read_result = read_file(filepath)
assert 'line2 modified' in read_result['content']
class TestToolDefinitions:
def test_get_tools_definition_returns_list(self):
tools = get_tools_definition()
assert isinstance(tools, list)
assert len(tools) > 0
def test_all_tools_have_required_fields(self):
tools = get_tools_definition()
for tool in tools:
assert 'type' in tool
assert tool['type'] == 'function'
assert 'function' in tool
func = tool['function']
assert 'name' in func
assert 'description' in func
assert 'parameters' in func
def test_filesystem_tools_present(self):
tools = get_tools_definition()
tool_names = [t['function']['name'] for t in tools]
assert 'read_file' in tool_names
assert 'write_file' in tool_names
assert 'list_directory' in tool_names
assert 'search_replace' in tool_names
def test_patch_tools_present(self):
tools = get_tools_definition()
tool_names = [t['function']['name'] for t in tools]
assert 'apply_patch' in tool_names
assert 'create_diff' in tool_names